diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml
index 992496b69..4308e4201 100644
--- a/.github/workflows/linting.yml
+++ b/.github/workflows/linting.yml
@@ -3,7 +3,7 @@ name: Linting
on: [push, pull_request]
jobs:
- pylint:
+ ruff-linting:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
@@ -11,19 +11,15 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: 3.11.10
- - name: Install pylint
+ - name: Install ruff
run: |
python -m pip install --upgrade pip
- pip install pylint==2.16.1
- - name: Run pylint
+ pip install ruff==0.12.0
+ - name: Run ruff linter
run: |
- pylint algoperf
- pylint reference_algorithms
- pylint prize_qualification_baselines
- pylint submission_runner.py
- pylint tests
+ ruff check
- isort:
+ ruff-formatter:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
@@ -31,26 +27,11 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: 3.11.10
- - name: Install isort
+ - name: Install ruff
run: |
python -m pip install --upgrade pip
- pip install isort==5.12.0
- - name: Run isort
+ pip install ruff==0.12.0
+ - name: Run ruff formatter
run: |
- isort . --check --diff
+ ruff format --check
- yapf:
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v2
- - name: Set up Python 3.11.10
- uses: actions/setup-python@v2
- with:
- python-version: 3.11.10
- - name: Install yapf
- run: |
- python -m pip install --upgrade pip
- pip install yapf==0.32 toml
- - name: Run yapf
- run: |
- yapf . --diff --recursive
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cc8f13d25..f2d684c53 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,14 +1,10 @@
repos:
- - repo: https://github.com/google/yapf
- rev: v0.32.0
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ # Ruff version.
+ rev: v0.12.0
hooks:
- - id: yapf
- args: ["--in-place", "--parallel", "--verbose", "--recursive"]
- - repo: https://github.com/pycqa/isort
- rev: 5.10.1
- hooks:
- - id: isort
- - repo: https://github.com/pycqa/pylint
- rev: v2.16.1
- hooks:
- - id: pylint
+ # Run the linter (don't change files).
+ - id: ruff-check
+ # Run the formatter (don't change files).
+ - id: ruff-format
+ args: ["--check"]
diff --git a/README.md b/README.md
index 8e470266d..0666d21d5 100644
--- a/README.md
+++ b/README.md
@@ -14,11 +14,11 @@
Benchmark/Results Paper
-[](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/CI.yml)
-[](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/linting.yml)
-[](https://github.com/mlcommons/algorithmic-efficiency/blob/main/LICENSE.md)
-[](https://github.com/google/yapf)
-[](https://discord.gg/5FPXK7SMt6)
+[](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/CI.yml)
+[](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/linting.yml)
+[](https://github.com/astral-sh/ruff)
+[](LICENSE.md)
+[](https://discord.gg/5FPXK7SMt6)
---
@@ -28,11 +28,12 @@ Submissions are evaluated based on their "time-to-result", i.e., the wall-clock
---
-> This is the repository for the *AlgoPerf: Training Algorithms benchmark* measuring neural network training speedups due to algorithmic improvements.
+> This is the repository for the _AlgoPerf: Training Algorithms benchmark_ measuring neural network training speedups due to algorithmic improvements.
> It is developed by the [MLCommons Algorithms Working Group](https://mlcommons.org/en/groups/research-algorithms/).
> This repository holds the benchmark code, the benchmark's [**technical documentation**](/docs/DOCUMENTATION.md) and [**getting started guides**](/docs/GETTING_STARTED.md). For a detailed description of the benchmark design, see our [**introductory paper**](https://arxiv.org/abs/2306.07179), for the results of the inaugural competition see our [**results paper**](https://openreview.net/forum?id=CtM5xjRSfm).
>
> **See our [AlgoPerf Leaderboard](https://github.com/mlcommons/submissions_algorithms) for the latest results of the benchmark and to submit your algorithm.**
+
---
> [!IMPORTANT]
@@ -50,14 +51,13 @@ Submissions are evaluated based on their "time-to-result", i.e., the wall-clock
## Installation
-> [!TIP]
-> **If you have any questions about the benchmark competition or you run into any issues, please feel free to contact us.** Either [file an issue](https://github.com/mlcommons/algorithmic-efficiency/issues), ask a question on [our Discord](https://discord.gg/5FPXK7SMt6) or [join our weekly meetings](https://mlcommons.org/en/groups/research-algorithms/).
+> [!TIP]
+> **If you have any questions about the benchmark competition or you run into any issues, please feel free to contact us.** Either [file an issue](https://github.com/mlcommons/algorithmic-efficiency/issues), ask a question on [our Discord](https://discord.gg/5FPXK7SMt6) or [join our weekly meetings](https://mlcommons.org/en/groups/research-algorithms/).
You can install this package and dependencies in a [Python virtual environment](/docs/GETTING_STARTED.md#python-virtual-environment) or use a [Docker/Singularity/Apptainer container](/docs/GETTING_STARTED.md#docker) (recommended).
We recommend using a Docker container (or alternatively, a Singularity/Apptainer container) to ensure a similar environment to our scoring and testing environments.
Both options are described in detail in the [**Getting Started**](/docs/GETTING_STARTED.md) document.
-*TL;DR to install the Jax version for GPU run:*
+_TL;DR to install the Jax version for GPU run:_
```bash
pip3 install -e '.[pytorch_cpu]'
@@ -65,7 +65,7 @@ pip3 install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax
pip3 install -e '.[full]'
```
-*TL;DR to install the PyTorch version for GPU run:*
+_TL;DR to install the PyTorch version for GPU run:_
```bash
pip3 install -e '.[jax_cpu]'
@@ -77,7 +77,7 @@ pip3 install -e '.[full]'
For detailed instructions on developing your own algorithm in the benchmark see the [Getting Started](/docs/GETTING_STARTED.md) document.
-*TL;DR running a JAX workload:*
+_TL;DR running a JAX workload:_
```bash
python3 submission_runner.py \
@@ -89,7 +89,7 @@ python3 submission_runner.py \
--tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json
```
-*TL;DR running a PyTorch workload:*
+_TL;DR running a PyTorch workload:_
```bash
python3 submission_runner.py \
@@ -117,17 +117,15 @@ Our [**Contributing**](/docs/CONTRIBUTING.md) document provides further MLCommon
## License
-The *AlgoPerf* codebase is licensed under the [Apache License 2.0](/LICENSE.md).
+The _AlgoPerf_ codebase is licensed under the [Apache License 2.0](/LICENSE.md).
## Paper and Citing the AlgoPerf Benchmark
-In our paper ["Benchmarking Neural Network Training Algorithms"](http://arxiv.org/abs/2306.07179) we motivate, describe, and justify the *AlgoPerf: Training Algorithms* benchmark.
+In our paper ["Benchmarking Neural Network Training Algorithms"](http://arxiv.org/abs/2306.07179) we motivate, describe, and justify the _AlgoPerf: Training Algorithms_ benchmark.
-If you are using the *AlgoPerf benchmark*, its codebase, baselines, or workloads, please consider citing our paper:
+If you are using the _AlgoPerf benchmark_, its codebase, baselines, or workloads, please consider citing our paper:
-> [Dahl, Schneider, Nado, et al.
-> **Benchmarking Neural Network Training Algorithms**
-> *arXiv 2306.07179*](http://arxiv.org/abs/2306.07179)
+> [Dahl, Schneider, Nado, et al.
> **Benchmarking Neural Network Training Algorithms**
> _arXiv 2306.07179_](http://arxiv.org/abs/2306.07179)
```bibtex
@Misc{Dahl2023AlgoPerf,
@@ -139,10 +137,9 @@ If you are using the *AlgoPerf benchmark*, its codebase, baselines, or workloads
}
```
-If you use the results from the first *AlgoPerf competition*, please consider citing the results paper, as well as the relevant submissions:
+If you use the results from the first _AlgoPerf competition_, please consider citing the results paper, as well as the relevant submissions:
-> [Kasimbeg, Schneider, Eschenhagen, et al.
-> **Accelerating neural network training: An analysis of the AlgoPerf competition**
+> [Kasimbeg, Schneider, Eschenhagen, et al.
> **Accelerating neural network training: An analysis of the AlgoPerf competition**
> ICLR 2025](https://openreview.net/forum?id=CtM5xjRSfm)
```bibtex
diff --git a/algoperf/__init__.py b/algoperf/__init__.py
index 7d54f8290..5ecee05af 100644
--- a/algoperf/__init__.py
+++ b/algoperf/__init__.py
@@ -2,4 +2,4 @@
from ._version import version as __version__
-__all__ = ["__version__"]
+__all__ = ['__version__']
diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py
index f4cb6c2db..f8cc40599 100644
--- a/algoperf/checkpoint_utils.py
+++ b/algoperf/checkpoint_utils.py
@@ -7,37 +7,41 @@
import os
from typing import Sequence, Tuple
+import jax
+import numpy as np
+import torch
from absl import logging
from flax import jax_utils
from flax.training import checkpoints as flax_checkpoints
from flax.training.checkpoints import latest_checkpoint
-import jax
-import numpy as np
from tensorflow.io import gfile # pytype: disable=import-error
-import torch
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
_, _, DEVICE, _ = pytorch_setup()
-CheckpointReturn = Tuple[spec.OptimizerState,
- spec.ParameterContainer,
- spec.ModelAuxiliaryState,
- dict,
- list,
- int,
- int]
-
-
-def maybe_restore_checkpoint(framework: str,
- optimizer_state: spec.OptimizerState,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- train_state: dict,
- eval_results: list,
- global_step: int,
- preemption_count: int,
- checkpoint_dir: str) -> CheckpointReturn:
+CheckpointReturn = Tuple[
+ spec.OptimizerState,
+ spec.ParameterContainer,
+ spec.ModelAuxiliaryState,
+ dict,
+ list,
+ int,
+ int,
+]
+
+
+def maybe_restore_checkpoint(
+ framework: str,
+ optimizer_state: spec.OptimizerState,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ train_state: dict,
+ eval_results: list,
+ global_step: int,
+ preemption_count: int,
+ checkpoint_dir: str,
+) -> CheckpointReturn:
"""Optionally restores from a checkpoint.
The checkpoint logic is as follows: if there is a checkpoint in
@@ -69,20 +73,22 @@ def maybe_restore_checkpoint(framework: str,
uninitialized_global_step = -1
uninitialized_preemption_count = -1
checkpoint_state = {
- 'model_params': model_params,
- 'optimizer_state': opt_state,
- 'model_state': model_state,
- 'train_state': train_state,
- 'eval_results': None,
- 'global_step': uninitialized_global_step,
- 'preemption_count': uninitialized_preemption_count,
+ 'model_params': model_params,
+ 'optimizer_state': opt_state,
+ 'model_state': model_state,
+ 'train_state': train_state,
+ 'eval_results': None,
+ 'global_step': uninitialized_global_step,
+ 'preemption_count': uninitialized_preemption_count,
}
if framework == 'jax':
latest_ckpt = flax_checkpoints.restore_checkpoint(
- checkpoint_dir, target=checkpoint_state)
- save_path = os.path.join(checkpoint_dir,
- 'checkpoint_' + str(latest_ckpt['global_step']))
+ checkpoint_dir, target=checkpoint_state
+ )
+ save_path = os.path.join(
+ checkpoint_dir, 'checkpoint_' + str(latest_ckpt['global_step'])
+ )
else:
latest_ckpt = checkpoint_state
save_path = latest_checkpoint(checkpoint_dir)
@@ -94,55 +100,64 @@ def maybe_restore_checkpoint(framework: str,
found_checkpoint = latest_ckpt['global_step'] != uninitialized_global_step
if not found_checkpoint:
- return (optimizer_state,
- model_params,
- model_state,
- train_state,
- eval_results,
- global_step,
- preemption_count)
+ return (
+ optimizer_state,
+ model_params,
+ model_state,
+ train_state,
+ eval_results,
+ global_step,
+ preemption_count,
+ )
# If there's the latest checkpoint in the checkpoint_dir, restore from that.
if framework == 'jax':
checkpoint_state = replicate_checkpoint(
- latest_ckpt,
- pytree_keys=[
- 'optimizer_state',
- 'model_params',
- 'model_state',
- ])
- checkpoint_state['optimizer_state'] = (checkpoint_state['optimizer_state'],
- opt_update_fn)
+ latest_ckpt,
+ pytree_keys=[
+ 'optimizer_state',
+ 'model_params',
+ 'model_state',
+ ],
+ )
+ checkpoint_state['optimizer_state'] = (
+ checkpoint_state['optimizer_state'],
+ opt_update_fn,
+ )
checkpoint_state['eval_results'] = [
- (value, key) for key, value in latest_ckpt['eval_results'].items()
+ (value, key) for key, value in latest_ckpt['eval_results'].items()
]
else:
checkpoint_state = latest_ckpt
if isinstance(
- model_params,
- (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
+ model_params,
+ (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel),
+ ):
model_params = model_params.module
model_params.load_state_dict(checkpoint_state['model_params'])
checkpoint_state['model_params'] = model_params
for key in optimizer_state.keys():
optimizer_state[key].load_state_dict(
- checkpoint_state['optimizer_state'][key])
+ checkpoint_state['optimizer_state'][key]
+ )
checkpoint_state['optimizer_state'][key] = optimizer_state[key]
logging.info(f'Loaded checkpoint from {save_path}.')
- return (checkpoint_state['optimizer_state'],
- checkpoint_state['model_params'],
- checkpoint_state['model_state'],
- checkpoint_state['train_state'],
- list(checkpoint_state['eval_results']),
- checkpoint_state['global_step'],
- checkpoint_state['preemption_count'] + 1)
-
-
-def replicate_checkpoint(latest: dict,
- pytree_keys: Sequence[str],
- replicate: bool = True) -> dict:
+ return (
+ checkpoint_state['optimizer_state'],
+ checkpoint_state['model_params'],
+ checkpoint_state['model_state'],
+ checkpoint_state['train_state'],
+ list(checkpoint_state['eval_results']),
+ checkpoint_state['global_step'],
+ checkpoint_state['preemption_count'] + 1,
+ )
+
+
+def replicate_checkpoint(
+ latest: dict, pytree_keys: Sequence[str], replicate: bool = True
+) -> dict:
"""Restores from the provided checkpoint.
Args:
@@ -163,16 +178,18 @@ def replicate_checkpoint(latest: dict,
return pytree
-def save_checkpoint(framework: str,
- optimizer_state: spec.OptimizerState,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- train_state: dict,
- eval_results: list,
- global_step: int,
- preemption_count: int,
- checkpoint_dir: str,
- save_intermediate_checkpoints: bool) -> None:
+def save_checkpoint(
+ framework: str,
+ optimizer_state: spec.OptimizerState,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ train_state: dict,
+ eval_results: list,
+ global_step: int,
+ preemption_count: int,
+ checkpoint_dir: str,
+ save_intermediate_checkpoints: bool,
+) -> None:
"""Save the checkpoint in `checkpoint_dir`.
Args:
@@ -199,8 +216,9 @@ def save_checkpoint(framework: str,
model_state = jax.device_get(jax_utils.unreplicate(model_state))
else:
if isinstance(
- model_params,
- (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
+ model_params,
+ (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel),
+ ):
model_params = model_params.module
model_params = model_params.state_dict()
optimizer_state_dict = {}
@@ -209,33 +227,36 @@ def save_checkpoint(framework: str,
optimizer_state_dict[key] = optimizer_state[key].state_dict()
else:
logging.warning(
- f'The optimizer state for key {key} is not saved, because '
- f'{type(optimizer_state[key])} has not implemented a state_dict() '
- 'method.')
+ f'The optimizer state for key {key} is not saved, because '
+ f'{type(optimizer_state[key])} has not implemented a state_dict() '
+ 'method.'
+ )
opt_state = optimizer_state_dict
checkpoint_state = {
- 'model_params': model_params,
- 'optimizer_state': opt_state,
- 'model_state': model_state,
- 'train_state': train_state,
- 'eval_results': tuple(eval_results),
- 'global_step': global_step,
- 'preemption_count': preemption_count,
+ 'model_params': model_params,
+ 'optimizer_state': opt_state,
+ 'model_state': model_state,
+ 'train_state': train_state,
+ 'eval_results': tuple(eval_results),
+ 'global_step': global_step,
+ 'preemption_count': preemption_count,
}
save_path = os.path.join(checkpoint_dir, f'checkpoint_{global_step}')
if framework == 'jax':
flax_checkpoints.save_checkpoint(
- checkpoint_dir,
- target=checkpoint_state,
- step=global_step,
- overwrite=True,
- keep=np.inf if save_intermediate_checkpoints else 1)
+ checkpoint_dir,
+ target=checkpoint_state,
+ step=global_step,
+ overwrite=True,
+ keep=np.inf if save_intermediate_checkpoints else 1,
+ )
else:
if not save_intermediate_checkpoints:
checkpoint_files = gfile.glob(
- os.path.join(checkpoint_dir, 'checkpoint_*'))
+ os.path.join(checkpoint_dir, 'checkpoint_*')
+ )
for path in checkpoint_files:
logging.info('Removing checkpoint at %s', path)
gfile.rmtree(path)
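Since `save_checkpoint` and `maybe_restore_checkpoint` are only reformatted here, their call signatures are unchanged. A minimal sketch of a PyTorch-side save call follows; the toy model, optimizer, `/tmp` directory, and `train_state` contents are illustrative assumptions, not part of this patch:

```python
import os

import torch

from algoperf import checkpoint_utils

model = torch.nn.Linear(2, 2)
optimizer_state = {'optimizer': torch.optim.SGD(model.parameters(), lr=0.1)}
checkpoint_dir = '/tmp/algoperf_ckpt_example'
os.makedirs(checkpoint_dir, exist_ok=True)

checkpoint_utils.save_checkpoint(
  framework='pytorch',
  optimizer_state=optimizer_state,  # dict of objects exposing state_dict()
  model_params=model,
  model_state=None,
  train_state={'accumulated_submission_time': 0.0},  # illustrative contents
  eval_results=[],
  global_step=0,
  preemption_count=0,
  checkpoint_dir=checkpoint_dir,
  save_intermediate_checkpoints=True,
)
```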
diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py
index 37d1bd20f..f08d9d2db 100644
--- a/algoperf/data_utils.py
+++ b/algoperf/data_utils.py
@@ -7,17 +7,16 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
-from torch.utils.data import DataLoader
-from torch.utils.data import DistributedSampler
-from torch.utils.data import Sampler
+from torch.utils.data import DataLoader, DistributedSampler, Sampler
from algoperf import spec
def shard_and_maybe_pad_np(
- batch: Dict[str, spec.Tensor],
- padding_value: int = 0,
- global_batch_size: Optional[int] = None) -> Dict[str, spec.Tensor]:
+ batch: Dict[str, spec.Tensor],
+ padding_value: int = 0,
+ global_batch_size: Optional[int] = None,
+) -> Dict[str, spec.Tensor]:
"""Prepare tf data for JAX or PyTorch DDP.
Convert an input batch from tf Tensors to numpy arrays, pad it with
@@ -26,11 +25,13 @@ def shard_and_maybe_pad_np(
"""
local_device_count = max(torch.cuda.device_count(), jax.local_device_count())
inputs = batch['inputs']
- current_batch_size = inputs[0].shape[0] if isinstance(
- inputs, tuple) else inputs.shape[0]
+ current_batch_size = (
+ inputs[0].shape[0] if isinstance(inputs, tuple) else inputs.shape[0]
+ )
if global_batch_size is not None:
- assert global_batch_size >= current_batch_size, \
- 'global_batch_size must be larger than or equal to current_batch_size.'
+ assert global_batch_size >= current_batch_size, (
+ 'global_batch_size must be larger than or equal to current_batch_size.'
+ )
# Always pad to global_batch_size if it is provided.
pad_to_global_batch_size = global_batch_size > current_batch_size
else:
@@ -43,7 +44,8 @@ def shard_and_maybe_pad_np(
pad_size = local_device_count - remainder_size
targets = batch['targets']
targets_shape = tuple(
- targets[0].shape if isinstance(targets, tuple) else targets.shape)
+ targets[0].shape if isinstance(targets, tuple) else targets.shape
+ )
# We need a 2d mask for WMT.
mask_shape = targets_shape if len(targets_shape) < 3 else targets_shape[0]
# Get weights from batch if there are any.
@@ -68,9 +70,9 @@ def _prepare(x):
return jax.tree.map(_prepare, batch)
-def pad(tensor: np.ndarray,
- pad_size: int,
- padding_value: int = 0) -> np.ndarray:
+def pad(
+ tensor: np.ndarray, pad_size: int, padding_value: int = 0
+) -> np.ndarray:
if tensor.ndim > 1:
pad_size = (pad_size, *tensor.shape[1:])
padding = np.full(pad_size, padding_value, dtype=tensor.dtype)
@@ -78,8 +80,9 @@ def pad(tensor: np.ndarray,
return padded_tensor
-def mixup_pytorch(batch: Tuple[spec.Tensor, spec.Tensor],
- alpha: float = 0.2) -> Tuple[spec.Tensor, spec.Tensor]:
+def mixup_pytorch(
+ batch: Tuple[spec.Tensor, spec.Tensor], alpha: float = 0.2
+) -> Tuple[spec.Tensor, spec.Tensor]:
inputs, targets = batch
# Transform to one-hot targets.
targets = F.one_hot(targets, num_classes=1000)
@@ -144,12 +147,14 @@ class DistributedEvalSampler(Sampler):
... train(loader)
"""
- def __init__(self,
- dataset: torch.utils.data.Dataset,
- num_replicas: Optional[int] = None,
- rank: Optional[int] = None,
- shuffle: bool = False,
- seed: int = 0) -> None:
+ def __init__(
+ self,
+ dataset: torch.utils.data.Dataset,
+ num_replicas: Optional[int] = None,
+ rank: Optional[int] = None,
+ shuffle: bool = False,
+ seed: int = 0,
+ ) -> None:
if num_replicas is None:
if not dist.is_available():
raise RuntimeError('Requires distributed package to be available.')
@@ -165,7 +170,7 @@ def __init__(self,
# true value without extra samples
self.total_size = len(self.dataset)
indices = list(range(self.total_size))
- indices = indices[self.rank:self.total_size:self.num_replicas]
+ indices = indices[self.rank : self.total_size : self.num_replicas]
# true value without extra samples
self.num_samples = len(indices)
@@ -182,7 +187,7 @@ def __iter__(self) -> Iterable[int]:
indices = list(range(len(self.dataset)))
# Subsample.
- indices = indices[self.rank:self.total_size:self.num_replicas]
+ indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
@@ -203,11 +208,13 @@ def set_epoch(self, epoch: int) -> None:
# Modified from github.com/pytorch/pytorch/issues/23900#issuecomment-518858050.
-def cycle(iterable: Iterable,
- keys: Tuple[str, ...] = ('inputs', 'targets'),
- custom_sampler: bool = False,
- use_mixup: bool = False,
- mixup_alpha: float = 0.2) -> Iterable:
+def cycle(
+ iterable: Iterable,
+ keys: Tuple[str, ...] = ('inputs', 'targets'),
+ custom_sampler: bool = False,
+ use_mixup: bool = False,
+ mixup_alpha: float = 0.2,
+) -> Iterable:
iterator = iter(iterable)
epoch = 0
while True:
@@ -229,11 +236,9 @@ def cycle(iterable: Iterable,
# github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Classification/
# ConvNets/image_classification/dataloaders.py
class PrefetchedWrapper:
-
- def __init__(self,
- dataloader: DataLoader,
- device: torch.device,
- start_epoch: int = 0) -> None:
+ def __init__(
+ self, dataloader: DataLoader, device: torch.device, start_epoch: int = 0
+ ) -> None:
self.dataloader = dataloader
self.epoch = start_epoch
self.device = device
@@ -254,11 +259,12 @@ def prefetched_loader(self) -> Iterable[Tuple[spec.Tensor, spec.Tensor]]:
for next_inputs, next_targets in self.dataloader:
with torch.cuda.stream(stream):
next_inputs = next_inputs.to(
- self.device, dtype=torch.float, non_blocking=True)
+ self.device, dtype=torch.float, non_blocking=True
+ )
next_targets = next_targets.to(self.device, non_blocking=True)
if not first:
- yield inputs, targets
+ yield inputs, targets # noqa: F821
else:
first = False
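For reference, a small sketch of how the reformatted (behaviorally unchanged) padding helpers are used; the concrete shapes in the comments assume a single local device and are illustrative only:

```python
import numpy as np

from algoperf.data_utils import pad, shard_and_maybe_pad_np

inputs = np.arange(6, dtype=np.float32).reshape(3, 2)
padded = pad(inputs, pad_size=1)  # appends one zero row -> shape (4, 2)

batch = {'inputs': inputs, 'targets': np.zeros(3, dtype=np.int64)}
sharded = shard_and_maybe_pad_np(batch, global_batch_size=4)
# With one local device, each array gains a leading device axis and is padded
# up to the global batch size, e.g. sharded['inputs'].shape == (1, 4, 2).
```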
diff --git a/algoperf/halton.py b/algoperf/halton.py
index 1f36b07bf..08c5466f1 100644
--- a/algoperf/halton.py
+++ b/algoperf/halton.py
@@ -36,10 +36,12 @@ def _is_prime(n: int) -> bool:
return all(n % i != 0 for i in range(2, int(n**0.5) + 1)) and n != 2
-def _generate_dim(num_samples: int,
- base: int,
- per_dim_shift: bool,
- shuffled_seed_sequence: List[int]) -> List[float]:
+def _generate_dim(
+ num_samples: int,
+ base: int,
+ per_dim_shift: bool,
+ shuffled_seed_sequence: List[int],
+) -> List[float]:
"""Generate `num_samples` from a Van der Corput sequence with base `base`.
Args:
@@ -59,8 +61,9 @@ def _generate_dim(num_samples: int,
ValueError: if `base` is negative or not prime.
"""
if base < 0 or not _is_prime(base):
- raise ValueError('Each Van der Corput sequence requires a prime `base`, '
- f'received {base}.')
+ raise ValueError(
+ f'Each Van der Corput sequence requires a prime `base`, received {base}.'
+ )
rng = random.RandomState(base)
if shuffled_seed_sequence is None:
@@ -76,7 +79,7 @@ def _generate_dim(num_samples: int,
dim_sequence = []
for i in range(1, num_samples + 1):
- num = 0.
+ num = 0.0
denominator = base
while i:
num += shuffled_seed_sequence[i % base] / denominator
@@ -91,13 +94,15 @@ def _generate_dim(num_samples: int,
Matrix = List[List[int]]
-def generate_sequence(num_samples: int,
- num_dims: int,
- skip: int = 100,
- per_dim_shift: bool = True,
- shuffle_sequence: bool = True,
- primes: Sequence[int] = None,
- shuffled_seed_sequence: Matrix = None) -> Matrix:
+def generate_sequence(
+ num_samples: int,
+ num_dims: int,
+ skip: int = 100,
+ per_dim_shift: bool = True,
+ shuffle_sequence: bool = True,
+ primes: Sequence[int] = None,
+ shuffled_seed_sequence: Matrix = None,
+) -> Matrix:
"""Generate `num_samples` from a Halton sequence of dimension `num_dims`.
Each dimension is generated independently from a shuffled Van der Corput
@@ -140,25 +145,29 @@ def generate_sequence(num_samples: int,
if primes is not None and len(primes) != num_dims:
raise ValueError(
- 'If passing in a sequence of primes it must be the same length as '
- f'num_dims={num_dims}, received {primes} (len {len(primes)}).')
+ 'If passing in a sequence of primes it must be the same length as '
+ f'num_dims={num_dims}, received {primes} (len {len(primes)}).'
+ )
if shuffled_seed_sequence is not None:
if len(shuffled_seed_sequence) != num_dims:
raise ValueError(
- 'If passing in `shuffled_seed_sequence` it must be the same length '
- f'as num_dims={num_dims}, received {shuffled_seed_sequence} '
- f'(len {len(shuffled_seed_sequence)}).')
+ 'If passing in `shuffled_seed_sequence` it must be the same length '
+ f'as num_dims={num_dims}, received {shuffled_seed_sequence} '
+ f'(len {len(shuffled_seed_sequence)}).'
+ )
for d in range(num_dims):
if len(shuffled_seed_sequence[d]) != primes[d]:
raise ValueError(
- 'If passing in `shuffled_seed_sequence` it must have element `{d}` '
- 'be a sequence of length `primes[{d}]`={expected}, received '
- '{actual} (len {length})'.format(
- d=d,
- expected=primes[d],
- actual=shuffled_seed_sequence[d],
- length=shuffled_seed_sequence[d]))
+ 'If passing in `shuffled_seed_sequence` it must have element `{d}` '
+ 'be a sequence of length `primes[{d}]`={expected}, received '
+ '{actual} (len {length})'.format(
+ d=d,
+ expected=primes[d],
+ actual=shuffled_seed_sequence[d],
+ length=shuffled_seed_sequence[d],
+ )
+ )
if primes is None:
primes = []
@@ -166,7 +175,7 @@ def generate_sequence(num_samples: int,
while len(primes) < num_dims + 1:
primes = generate_primes(1000 * prime_attempts)
prime_attempts += 1
- primes = primes[-num_dims - 1:-1]
+ primes = primes[-num_dims - 1 : -1]
# Skip the first `skip` points in the sequence because they can have unwanted
# correlations.
@@ -179,10 +188,11 @@ def generate_sequence(num_samples: int,
else:
dim_shuffled_seed_sequence = shuffled_seed_sequence[d]
dim_sequence = _generate_dim(
- num_samples=num_samples,
- base=primes[d],
- shuffled_seed_sequence=dim_shuffled_seed_sequence,
- per_dim_shift=per_dim_shift)
+ num_samples=num_samples,
+ base=primes[d],
+ shuffled_seed_sequence=dim_shuffled_seed_sequence,
+ per_dim_shift=per_dim_shift,
+ )
dim_sequence = dim_sequence[skip:]
halton_sequence.append(dim_sequence)
@@ -195,29 +205,29 @@ def generate_sequence(num_samples: int,
return halton_sequence
-def _generate_double_point(name: str,
- min_val: float,
- max_val: float,
- scaling: str,
- halton_point: float) -> Tuple[str, float]:
+def _generate_double_point(
+ name: str, min_val: float, max_val: float, scaling: str, halton_point: float
+) -> Tuple[str, float]:
"""Generate a float hyperparameter value from a Halton sequence point."""
if scaling not in ['linear', 'log']:
raise ValueError(
- 'Only log or linear scaling is supported for floating point '
- f'parameters. Received {scaling}.')
+ 'Only log or linear scaling is supported for floating point '
+ f'parameters. Received {scaling}.'
+ )
if scaling == 'log':
# To transform from [0, 1] to [min_val, max_val] on a log scale we do:
# min_val * exp(x * log(max_val / min_val)).
- rescaled_value = (
- min_val * math.exp(halton_point * math.log(max_val / min_val)))
+ rescaled_value = min_val * math.exp(
+ halton_point * math.log(max_val / min_val)
+ )
else:
rescaled_value = halton_point * (max_val - min_val) + min_val
return name, rescaled_value
-def _generate_discrete_point(name: str,
- feasible_points: Sequence[Any],
- halton_point: float) -> Any:
+def _generate_discrete_point(
+ name: str, feasible_points: Sequence[Any], halton_point: float
+) -> Any:
"""Generate a discrete hyperparameter value from a Halton sequence point."""
index = int(math.floor(halton_point * len(feasible_points)))
return name, feasible_points[index]
@@ -236,27 +246,23 @@ def interval(start: int, end: int) -> Tuple[int, int]:
def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn:
min_val, max_val = range_endpoints
- return functools.partial(_generate_double_point,
- name,
- min_val,
- max_val,
- 'log')
+ return functools.partial(
+ _generate_double_point, name, min_val, max_val, 'log'
+ )
def uniform(
- name: str, search_points: Union[_DiscretePoints,
- Tuple[int, int]]) -> _GeneratorFn:
+ name: str, search_points: Union[_DiscretePoints, Tuple[int, int]]
+) -> _GeneratorFn:
if isinstance(search_points, _DiscretePoints):
- return functools.partial(_generate_discrete_point,
- name,
- search_points.feasible_points)
+ return functools.partial(
+ _generate_discrete_point, name, search_points.feasible_points
+ )
min_val, max_val = search_points
- return functools.partial(_generate_double_point,
- name,
- min_val,
- max_val,
- 'linear')
+ return functools.partial(
+ _generate_double_point, name, min_val, max_val, 'linear'
+ )
def product(sweeps: Sequence[_SweepSequence]) -> _SweepSequence:
@@ -277,9 +283,10 @@ def sweep(name, feasible_points: Sequence[Any]) -> _SweepSequence:
return [{name: x} for x in feasible_points.feasible_points]
-def zipit(generator_fns_or_sweeps: Sequence[Union[_GeneratorFn,
- _SweepSequence]],
- length: int) -> _SweepSequence:
+def zipit(
+ generator_fns_or_sweeps: Sequence[Union[_GeneratorFn, _SweepSequence]],
+ length: int,
+) -> _SweepSequence:
"""Zip together a list of hyperparameter generators.
Args:
@@ -302,7 +309,8 @@ def zipit(generator_fns_or_sweeps: Sequence[Union[_GeneratorFn,
hyperparameter name from generator_fns_or_sweeps.
"""
halton_sequence = generate_sequence(
- num_samples=length, num_dims=len(generator_fns_or_sweeps))
+ num_samples=length, num_dims=len(generator_fns_or_sweeps)
+ )
# A List[Dict] of hyperparameter names to sweep values.
hyperparameter_sweep = []
for trial_index in range(length):
@@ -326,8 +334,9 @@ def zipit(generator_fns_or_sweeps: Sequence[Union[_GeneratorFn,
_ListSearchSpace = List[Dict[str, Union[str, float, Sequence]]]
-def generate_search(search_space: Union[_DictSearchSpace, _ListSearchSpace],
- num_trials: int) -> List[collections.namedtuple]:
+def generate_search(
+ search_space: Union[_DictSearchSpace, _ListSearchSpace], num_trials: int
+) -> List[collections.namedtuple]:
"""Generate a random search with the given bounds and scaling.
 Args:
@@ -352,8 +361,9 @@ def generate_search(search_space: Union[_DictSearchSpace, _ListSearchSpace],
else:
raise AttributeError('tuning_search_space should either be a dict or list.')
- named_tuple_class = collections.namedtuple('Hyperparameters',
- all_hyperparameter_names)
+ named_tuple_class = collections.namedtuple(
+ 'Hyperparameters', all_hyperparameter_names
+ )
if isinstance(search_space, dict):
hyperparameter_generators = []
@@ -367,16 +377,18 @@ def generate_search(search_space: Union[_DictSearchSpace, _ListSearchSpace],
generator_fn = uniform(name, interval(space['min'], space['max']))
hyperparameter_generators.append(generator_fn)
return [
- named_tuple_class(**p)
- for p in zipit(hyperparameter_generators, num_trials)
+ named_tuple_class(**p)
+ for p in zipit(hyperparameter_generators, num_trials)
]
else:
hyperparameters = []
updated_num_trials = min(num_trials, len(search_space))
if num_trials != len(search_space):
- logging.info(f'--num_tuning_trials was set to {num_trials}, but '
- f'{len(search_space)} trial(s) found in the JSON file. '
- f'Updating --num_tuning_trials to {updated_num_trials}.')
+ logging.info(
+ f'--num_tuning_trials was set to {num_trials}, but '
+ f'{len(search_space)} trial(s) found in the JSON file. '
+ f'Updating --num_tuning_trials to {updated_num_trials}.'
+ )
for trial in search_space:
hyperparameters.append(named_tuple_class(**trial))
return hyperparameters[:updated_num_trials]
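The tuning utilities above are only reformatted; a short sketch of the dict-style search space they consume is shown below. The hyperparameter names and ranges are made up for illustration:

```python
from algoperf import halton

search_space = {
  'learning_rate': {'min': 1e-4, 'max': 1e-1, 'scaling': 'log'},
  'one_minus_beta1': {'min': 1e-2, 'max': 0.5, 'scaling': 'linear'},
}

# Returns a list of `Hyperparameters` namedtuples drawn from a Halton sequence.
trials = halton.generate_search(search_space, num_trials=5)
for trial in trials:
  print(trial.learning_rate, trial.one_minus_beta1)
```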
diff --git a/algoperf/init_utils.py b/algoperf/init_utils.py
index 185480cc7..c66a0be20 100644
--- a/algoperf/init_utils.py
+++ b/algoperf/init_utils.py
@@ -12,7 +12,7 @@
def pytorch_default_init(module: nn.Module) -> None:
# Perform lecun_normal initialization.
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
- std = math.sqrt(1. / fan_in) / .87962566103423978
+ std = math.sqrt(1.0 / fan_in) / 0.87962566103423978
nn.init.trunc_normal_(module.weight, std=std, a=-2 * std, b=2 * std)
if module.bias is not None:
- nn.init.constant_(module.bias, 0.)
+ nn.init.constant_(module.bias, 0.0)
diff --git a/algoperf/interop_utils.py b/algoperf/interop_utils.py
index 0c6535d7a..c30d0cf3b 100644
--- a/algoperf/interop_utils.py
+++ b/algoperf/interop_utils.py
@@ -6,7 +6,8 @@
def jax_to_pytorch(x: spec.Tensor, take_ownership: bool = False) -> spec.Tensor:
return torch.utils.dlpack.from_dlpack(
- jax.dlpack.to_dlpack(x, take_ownership=take_ownership))
+ jax.dlpack.to_dlpack(x, take_ownership=take_ownership)
+ )
def pytorch_to_jax(x: torch.Tensor) -> spec.Tensor:
diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py
new file mode 100644
index 000000000..dab338328
--- /dev/null
+++ b/algoperf/jax_utils.py
@@ -0,0 +1,129 @@
+from collections.abc import Sequence
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.linen.module import Module, compact, merge_param
+from flax.typing import PRNGKey
+from jax import lax, random
+
+
+# Custom Layers
+class Dropout(Module):
+ # pylint: disable=line-too-long
+ """Create a dropout layer.
+ Forked from
+ https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout.
+ The reference dropout implementation is modified to support changes
+ to the dropout rate during training by:
+ 1) adding a rate argument to the __call__ method.
+ 2) removing the if-else condition that checks for edge cases, which
+ would trigger a recompile of jitted code.
+
+ .. note::
+ When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
+ to include an RNG seed named ``'dropout'``. Dropout isn't necessary for
+ variable initialization.
+
+ Example usage::
+
+ >>> import flax.linen as nn
+ >>> import jax, jax.numpy as jnp
+
+ >>> class MLP(nn.Module):
+ ... @nn.compact
+ ... def __call__(self, x, train):
+ ... x = nn.Dense(4)(x)
+ ... x = nn.Dropout(0.5, deterministic=not train)(x)
+ ... return x
+
+ >>> model = MLP()
+ >>> x = jnp.ones((1, 3))
+ >>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout
+ >>> model.apply(variables, x, train=False) # don't use dropout
+ Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32)
+ >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout
+ Array([[-0.35751054, 3.2510893 , 0. , 0. ]], dtype=float32)
+
+ Attributes:
+ rate: the dropout probability. (_not_ the keep rate!)
+ broadcast_dims: dimensions that will share the same dropout mask
+ deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
+ and masked, whereas if true, no mask is applied and the inputs are
+ returned as is.
+ rng_collection: the rng collection name to use when requesting an rng
+ key.
+ """
+
+ rate: float | None = None
+ broadcast_dims: Sequence[int] = ()
+ deterministic: bool | None = None
+ rng_collection: str = 'dropout'
+ legacy: bool = False
+
+ @compact
+ def __call__(
+ self,
+ inputs,
+ deterministic: bool | None = None,
+ rate: float | None = None,
+ rng: PRNGKey | None = None,
+ ):
+ """Applies a random dropout mask to the input.
+
+ Args:
+ inputs: the inputs that should be randomly masked.
+ deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
+ and masked, whereas if true, no mask is applied and the inputs are
+ returned as is.
+ rate: the dropout probability. (_not_ the keep rate!)
+ rng: an optional PRNGKey used as the random key, if not specified,
+ one will be generated using ``make_rng`` with the
+ ``rng_collection`` name.
+
+ Returns:
+ The masked inputs reweighted to preserve mean.
+ """
+ deterministic = merge_param(
+ 'deterministic', self.deterministic, deterministic
+ )
+
+ # Override self.rate if rate is passed to __call__
+ if rate is None:
+ rate = self.rate
+
+ if self.legacy:
+ if rate == 0.0:
+ return inputs
+
+ # Prevent gradient NaNs in 1.0 edge-case.
+ if rate == 1.0:
+ return jnp.zeros_like(inputs)
+
+ if deterministic:
+ return inputs
+
+ keep_prob = 1.0 - rate
+ if rng is None:
+ rng = self.make_rng(self.rng_collection)
+ broadcast_shape = list(inputs.shape)
+ for dim in self.broadcast_dims:
+ broadcast_shape[dim] = 1
+ mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
+ mask = jnp.broadcast_to(mask, inputs.shape)
+ return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
+
+
+# Utilities for debugging
+def print_jax_model_summary(model, fake_inputs):
+ """Prints a summary of the jax module."""
+ tabulate_fn = nn.tabulate(
+ model,
+ jax.random.PRNGKey(0),
+ console_kwargs={
+ 'force_terminal': False,
+ 'force_jupyter': False,
+ 'width': 240,
+ },
+ )
+ print(tabulate_fn(fake_inputs, train=False))
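A minimal usage sketch of the new `Dropout` layer, passing the dropout rate at call time; the `TinyMLP` module and the rate values are illustrative assumptions:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

from algoperf.jax_utils import Dropout


class TinyMLP(nn.Module):
  @nn.compact
  def __call__(self, x, train: bool, dropout_rate: float):
    x = nn.Dense(4)(x)
    # The dropout rate is supplied per call instead of being fixed at
    # construction, matching the modification described in the docstring.
    x = Dropout(deterministic=not train)(x, rate=dropout_rate)
    return x


model = TinyMLP()
x = jnp.ones((1, 3))
variables = model.init(jax.random.PRNGKey(0), x, train=False, dropout_rate=0.0)
out = model.apply(
  variables,
  x,
  train=True,
  dropout_rate=0.1,
  rngs={'dropout': jax.random.PRNGKey(1)},
)
```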
diff --git a/algoperf/logger_utils.py b/algoperf/logger_utils.py
index c988956dc..17eea74a6 100644
--- a/algoperf/logger_utils.py
+++ b/algoperf/logger_utils.py
@@ -11,12 +11,12 @@
import sys
from typing import Any, Dict, Optional
-from absl import flags
-from clu import metric_writers
import GPUtil
import pandas as pd
import psutil
import torch.distributed as dist
+from absl import flags
+from clu import metric_writers
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
@@ -37,12 +37,12 @@ def makedir(dir_name: str, exist_ok: bool = True) -> None:
def get_log_dir(
- experiment_dir: str,
- workload: spec.Workload,
- framework: str,
- experiment_name: str,
- resume_last_run: bool,
- overwrite: bool,
+ experiment_dir: str,
+ workload: spec.Workload,
+ framework: str,
+ experiment_name: str,
+ resume_last_run: bool,
+ overwrite: bool,
) -> Optional[str]:
# Construct path to experiment workload directory.
experiment_dir = os.path.expanduser(experiment_dir)
@@ -50,26 +50,29 @@ def get_log_dir(
if experiment_name is None:
experiment_path = os.path.join(experiment_dir, workload_dir_name)
else:
- experiment_path = os.path.join(experiment_dir,
- experiment_name,
- workload_dir_name)
+ experiment_path = os.path.join(
+ experiment_dir, experiment_name, workload_dir_name
+ )
if os.path.exists(experiment_path):
if overwrite:
logging.info(
- f'Removing existing experiment directory {experiment_path} because '
- '--overwrite was set.')
+ f'Removing existing experiment directory {experiment_path} because '
+ '--overwrite was set.'
+ )
if RANK == 0:
shutil.rmtree(experiment_path)
elif resume_last_run:
logging.info(
- f'Resuming from experiment directory {experiment_path} because '
- '--resume_last_run was set.')
+ f'Resuming from experiment directory {experiment_path} because '
+ '--resume_last_run was set.'
+ )
else:
if RANK == 0:
resume = input(
- 'Found existing experiment dir with the same name: {}. Do you wish '
- 'to resume training from this dir? [y/N]:'.format(experiment_path))
+ 'Found existing experiment dir with the same name: {}. Do you wish '
+ 'to resume training from this dir? [y/N]:'.format(experiment_path)
+ )
if resume.lower() != 'y':
sys.exit()
@@ -83,16 +86,18 @@ def get_log_dir(
return experiment_path
-def write_hparams(hparams: spec.Hyperparameters,
- tuning_dir: str) -> spec.Hyperparameters:
+def write_hparams(
+ hparams: spec.Hyperparameters, tuning_dir: str
+) -> spec.Hyperparameters:
hparams_file_name = os.path.join(tuning_dir, 'hparams.json')
if os.path.exists(hparams_file_name):
# If hparams.json already exist, use the previously saved hyperparameters.
logging.info('Loading hparams from %s.', hparams_file_name)
with open(hparams_file_name, 'r') as f:
hparams_dict = json.load(f)
- hparams = collections.namedtuple('Hyperparameters',
- hparams_dict)(**hparams_dict)
+ hparams = collections.namedtuple('Hyperparameters', hparams_dict)(
+ **hparams_dict
+ )
else:
logging.info('Saving hparams to %s.', hparams_file_name)
if RANK == 0:
@@ -108,8 +113,8 @@ def write_json(name: str, log_dict: Dict, indent: int = 2) -> None:
def write_to_csv(
- metrics: Dict,
- csv_path: str,
+ metrics: Dict,
+ csv_path: str,
) -> None:
try:
with open(csv_path, 'r') as csv_file:
@@ -118,8 +123,10 @@ def write_to_csv(
except (pd.errors.EmptyDataError, FileNotFoundError) as e:
measurements = pd.DataFrame([metrics], columns=sorted(metrics.keys()))
if isinstance(e, pd.errors.EmptyDataError):
- logging.info('Measurements file is empty. Create a new one, starting '
- 'with metrics from this step.')
+ logging.info(
+ 'Measurements file is empty. Create a new one, starting '
+ 'with metrics from this step.'
+ )
with open(csv_path, 'w') as csv_file:
measurements.to_csv(csv_file, index=False)
return
@@ -130,7 +137,8 @@ def _get_utilization() -> Dict:
# CPU
util_data['cpu.util.avg_percent_since_last'] = psutil.cpu_percent(
- interval=None) # non-blocking (cpu util percentage since last call)
+ interval=None
+ ) # non-blocking (cpu util percentage since last call)
util_data['cpu.freq.current'] = psutil.cpu_freq().current
# Memory
@@ -190,7 +198,7 @@ def _get_system_hardware_info() -> Dict:
try:
system_hardware_info['cpu_model_name'] = _get_cpu_model_name()
system_hardware_info['cpu_count'] = psutil.cpu_count()
- except: # pylint: disable=bare-except
+ except: # noqa: E722
logging.info('Unable to record cpu information. Continuing without it.')
gpus = GPUtil.getGPUs()
@@ -199,7 +207,7 @@ def _get_system_hardware_info() -> Dict:
system_hardware_info['gpu_model_name'] = gpus[0].name
system_hardware_info['gpu_count'] = len(gpus)
system_hardware_info['gpu_driver'] = gpus[0].driver
- except: # pylint: disable=bare-except
+ except: # noqa: E722
logging.info('Unable to record gpu information. Continuing without it.')
return system_hardware_info
@@ -208,11 +216,14 @@ def _get_system_hardware_info() -> Dict:
def _get_system_software_info() -> Dict:
system_software_info = {}
- system_software_info['os_platform'] = \
- platform.platform() # Ex. 'Linux-5.4.48-x86_64-with-glibc2.29'
- system_software_info['python_version'] = platform.python_version(
+ system_software_info['os_platform'] = (
+ platform.platform()
+ ) # Ex. 'Linux-5.4.48-x86_64-with-glibc2.29'
+ system_software_info['python_version'] = (
+ platform.python_version()
) # Ex. '3.11.10'
- system_software_info['python_compiler'] = platform.python_compiler(
+ system_software_info['python_compiler'] = (
+ platform.python_compiler()
) # Ex. 'GCC 9.3.0'
# Note: do not store hostname as that may be sensitive
@@ -221,26 +232,35 @@ def _get_system_software_info() -> Dict:
system_software_info['git_commit_hash'] = _get_git_commit_hash()
# Note: do not store git repo url as it may be sensitive or contain a
# secret.
- except: # pylint: disable=bare-except
+ except: # noqa: E722
logging.info('Unable to record git information. Continuing without it.')
return system_software_info
def _get_git_commit_hash() -> str:
- return subprocess.check_output(['git', 'rev-parse',
- 'HEAD']).decode('ascii').strip()
+ return (
+ subprocess.check_output(['git', 'rev-parse', 'HEAD'])
+ .decode('ascii')
+ .strip()
+ )
def _get_git_branch() -> str:
- return subprocess.check_output(['git', 'rev-parse', '--abbrev-ref',
- 'HEAD']).decode('ascii').strip()
+ return (
+ subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
+ .decode('ascii')
+ .strip()
+ )
def _get_cpu_model_name() -> str:
output = subprocess.check_output(['lscpu']).decode('ascii').strip()
- return re.findall(r'(?=Model name:\s{1,}).*',
- output)[0].split('Model name:')[1].strip()
+ return (
+ re.findall(r'(?=Model name:\s{1,}).*', output)[0]
+ .split('Model name:')[1]
+ .strip()
+ )
def _is_primitive_type(item: Any) -> bool:
@@ -252,23 +272,25 @@ def _get_workload_properties(workload: spec.Workload) -> Dict:
workload_properties = {}
skip_list = ['param_shapes', 'model_params_types']
keys = [
- key for key in dir(workload)
- if not key.startswith('_') and key not in skip_list
+ key
+ for key in dir(workload)
+ if not key.startswith('_') and key not in skip_list
]
for key in keys:
try:
attr = getattr(workload, key)
- except: # pylint: disable=bare-except
+ except: # noqa: E722
logging.info(
- f'Unable to record workload.{key} information. Continuing without it.'
+ f'Unable to record workload.{key} information. Continuing without it.'
)
if _is_primitive_type(attr):
workload_properties[f'workload.{key}'] = attr
return workload_properties
-def get_meta_data(workload: spec.Workload,
- rng_seed: Optional[int] = None) -> Dict:
+def get_meta_data(
+ workload: spec.Workload, rng_seed: Optional[int] = None
+) -> Dict:
meta_data = {}
workload_properties = _get_workload_properties(workload)
meta_data.update(workload_properties)
@@ -290,12 +312,14 @@ class MetricLogger(object):
the wrong time.
"""
- def __init__(self,
- csv_path: str,
- eval_csv_path: str,
- events_dir: Optional[str] = None,
- configs: Optional[flags.FLAGS] = None,
- hyperparameters: Optional[spec.Hyperparameters] = None) -> None:
+ def __init__(
+ self,
+ csv_path: str,
+ eval_csv_path: str,
+ events_dir: Optional[str] = None,
+ configs: Optional[flags.FLAGS] = None,
+ hyperparameters: Optional[spec.Hyperparameters] = None,
+ ) -> None:
self._measurements = {}
self._csv_path = csv_path
self._eval_csv_path = eval_csv_path
@@ -305,15 +329,18 @@ def __init__(self,
self._tb_metric_writer = metric_writers.create_default_writer(events_dir)
if wandb is not None and self.use_wandb:
wandb.init(
- dir=events_dir, tags=[flags.FLAGS.workload, flags.FLAGS.framework])
+ dir=events_dir, tags=[flags.FLAGS.workload, flags.FLAGS.framework]
+ )
wandb.config.update(configs)
wandb.config.update(hyperparameters._asdict())
- def append_scalar_metrics(self,
- metrics: Dict,
- global_step: int,
- preemption_count: Optional[int] = None,
- is_eval: bool = False) -> None:
+ def append_scalar_metrics(
+ self,
+ metrics: Dict,
+ global_step: int,
+ preemption_count: Optional[int] = None,
+ is_eval: bool = False,
+ ) -> None:
metrics['global_step'] = global_step
if preemption_count is not None:
metrics['preemption_count'] = preemption_count
@@ -324,7 +351,8 @@ def append_scalar_metrics(self,
if self._tb_metric_writer:
self._tb_metric_writer.write_scalars(
- step=int(metrics['global_step']), scalars=metrics)
+ step=int(metrics['global_step']), scalars=metrics
+ )
self._tb_metric_writer.flush()
if wandb is not None and self.use_wandb:
@@ -335,15 +363,16 @@ def finish(self) -> None:
wandb.finish()
-def set_up_loggers(train_dir: str,
- configs: flags.FLAGS,
- hyperparameters: spec.Hyperparameters) -> MetricLogger:
+def set_up_loggers(
+ train_dir: str, configs: flags.FLAGS, hyperparameters: spec.Hyperparameters
+) -> MetricLogger:
csv_path = os.path.join(train_dir, 'measurements.csv')
eval_csv_path = os.path.join(train_dir, 'eval_measurements.csv')
metrics_logger = MetricLogger(
- csv_path=csv_path,
- eval_csv_path=eval_csv_path,
- events_dir=train_dir,
- configs=configs,
- hyperparameters=hyperparameters)
+ csv_path=csv_path,
+ eval_csv_path=eval_csv_path,
+ events_dir=train_dir,
+ configs=configs,
+ hyperparameters=hyperparameters,
+ )
return metrics_logger
diff --git a/algoperf/param_utils.py b/algoperf/param_utils.py
index 05d882404..908ef0f27 100644
--- a/algoperf/param_utils.py
+++ b/algoperf/param_utils.py
@@ -14,7 +14,8 @@ def pytorch_param_shapes(model: nn.Module) -> Dict[str, spec.ShapeTuple]:
def pytorch_param_types(
- param_shapes: Dict[str, spec.ShapeTuple]) -> Dict[str, spec.ParameterType]:
+ param_shapes: Dict[str, spec.ShapeTuple],
+) -> Dict[str, spec.ParameterType]:
param_types = {}
for name in param_shapes.keys():
if 'bn' in name:
@@ -65,18 +66,21 @@ def pytorch_param_types(
def jax_param_shapes(
- params: spec.ParameterContainer) -> spec.ParameterShapeTree:
+ params: spec.ParameterContainer,
+) -> spec.ParameterShapeTree:
return jax.tree.map(lambda x: spec.ShapeTuple(x.shape), params)
-def jax_param_types(param_shapes: spec.ParameterShapeTree,
- parent_name: str = '') -> Dict[str, spec.ParameterType]:
+def jax_param_types(
+ param_shapes: spec.ParameterShapeTree, parent_name: str = ''
+) -> Dict[str, spec.ParameterType]:
param_types = {}
for name, value in param_shapes.items():
name = name.lower()
if isinstance(value, dict) or isinstance(value, flax.core.FrozenDict):
param_types[name] = jax_param_types(
- value, parent_name=parent_name + '/' + name)
+ value, parent_name=parent_name + '/' + name
+ )
else:
if 'batchnorm' in parent_name or 'bn' in parent_name:
if name == 'scale':
@@ -85,7 +89,8 @@ def jax_param_types(param_shapes: spec.ParameterShapeTree,
param_types[name] = spec.ParameterType.BATCH_NORM_BIAS
else:
raise ValueError(
- f'Unrecognized batch norm parameter: {parent_name}/{name}.')
+ f'Unrecognized batch norm parameter: {parent_name}/{name}.'
+ )
elif 'layernorm' in parent_name or 'ln' in parent_name:
if name == 'scale':
param_types[name] = spec.ParameterType.LAYER_NORM_SCALE
@@ -93,7 +98,8 @@ def jax_param_types(param_shapes: spec.ParameterShapeTree,
param_types[name] = spec.ParameterType.LAYER_NORM_BIAS
else:
raise ValueError(
- f'Unrecognized layer norm parameter: {parent_name}/{name}.')
+ f'Unrecognized layer norm parameter: {parent_name}/{name}.'
+ )
elif 'conv' in parent_name:
if 'bias' in name:
param_types[name] = spec.ParameterType.BIAS
@@ -102,8 +108,9 @@ def jax_param_types(param_shapes: spec.ParameterShapeTree,
# Note that this is exact equality, not contained in, because
# flax.linen.Embed names the embedding parameter "embedding"
# https://github.com/google/flax/blob/main/flax/linen/linear.py#L604.
- elif ('embedding' in name or
- ('embedding' in parent_name and name == 'kernel')):
+ elif 'embedding' in name or (
+ 'embedding' in parent_name and name == 'kernel'
+ ):
param_types[name] = spec.ParameterType.EMBEDDING
elif 'attention' in parent_name:
if name == 'bias':
@@ -122,7 +129,8 @@ def jax_param_types(param_shapes: spec.ParameterShapeTree,
param_types[name] = spec.ParameterType.ATTENTION_QKV
else:
raise ValueError(
- f'Unrecognized attention parameter: {parent_name}/{name}.')
+ f'Unrecognized attention parameter: {parent_name}/{name}.'
+ )
elif 'bias' in name:
param_types[name] = spec.ParameterType.BIAS
else:
diff --git a/algoperf/profiler.py b/algoperf/profiler.py
index fa2a1bee2..534a5ccfb 100644
--- a/algoperf/profiler.py
+++ b/algoperf/profiler.py
@@ -4,10 +4,10 @@
https://github.com/Lightning-AI/lightning/tree/master/src/pytorch_lightning/profilers.
"""
-from collections import defaultdict
-from contextlib import contextmanager
import os
import time
+from collections import defaultdict
+from contextlib import contextmanager
from typing import Dict, Generator, List, Optional, Tuple
import numpy as np
@@ -21,7 +21,6 @@ def _get_monotonic_time() -> float:
class Profiler:
-
def __init__(self, local_rank: Optional[int] = None) -> None:
self._local_rank = local_rank
@@ -41,7 +40,8 @@ def start(self, action_name: str) -> None:
pass
if action_name in self.current_actions:
raise ValueError(
- f'Attempted to start {action_name} which has already started.')
+ f'Attempted to start {action_name} which has already started.'
+ )
self.current_actions[action_name] = _get_monotonic_time()
def stop(self, action_name: str) -> None:
@@ -49,8 +49,10 @@ def stop(self, action_name: str) -> None:
pass
end_time = _get_monotonic_time()
if action_name not in self.current_actions:
- raise ValueError(f'Attempting to stop recording an action '
- f'({action_name}) which was never started.')
+ raise ValueError(
+ f'Attempting to stop recording an action '
+ f'({action_name}) which was never started.'
+ )
start_time = self.current_actions.pop(action_name)
duration = end_time - start_time
self.recorded_durations[action_name].append(duration)
@@ -64,16 +66,20 @@ def profile(self, action_name: str) -> Generator:
self.stop(action_name)
def _make_report(
- self
+ self,
) -> Tuple[List[Tuple[str, float, float, int, float, float]], int, float]:
total_duration = _get_monotonic_time() - self.start_time
- report = [(str(a),
- float(np.mean(d)),
- float(np.std(d)),
- len(d),
- float(np.sum(d)),
- 100.0 * float(np.sum(d)) / total_duration) for a,
- d in self.recorded_durations.items()]
+ report = [
+ (
+ str(a),
+ float(np.mean(d)),
+ float(np.std(d)),
+ len(d),
+ float(np.sum(d)),
+ 100.0 * float(np.sum(d)) / total_duration,
+ )
+ for a, d in self.recorded_durations.items()
+ ]
report.sort(key=lambda x: x[5], reverse=True)
total_calls = sum(x[3] for x in report)
return report, total_calls, total_duration
@@ -92,32 +98,42 @@ def log_row(action, mean, std, num_calls, total, per):
row += f' {num_calls:<15}\t| {total:<15}\t| {per:<15}\t|'
return row
- header_string = log_row('Action',
- 'Mean Duration (s)',
- 'Std Duration (s)',
- 'Num Calls',
- 'Total Time (s)',
- 'Percentage %')
+ header_string = log_row(
+ 'Action',
+ 'Mean Duration (s)',
+ 'Std Duration (s)',
+ 'Num Calls',
+ 'Total Time (s)',
+ 'Percentage %',
+ )
output_string_len = len(header_string.expandtabs())
sep_lines = f'{sep}{"-" * output_string_len}'
output_string += sep_lines + header_string + sep_lines
report, total_calls, total_duration = self._make_report()
- output_string += log_row('Total',
- '-----',
- '-----',
- f'{total_calls:}',
- f'{total_duration:.5}',
- '100 %')
+ output_string += log_row(
+ 'Total',
+ '-----',
+ '-----',
+ f'{total_calls:}',
+ f'{total_duration:.5}',
+ '100 %',
+ )
output_string += sep_lines
- for action, mean_duration, std_duration, num_calls, \
- total_duration, duration_per in report:
+ for (
+ action,
+ mean_duration,
+ std_duration,
+ num_calls,
+ total_duration,
+ duration_per,
+ ) in report:
output_string += log_row(
- action,
- f'{mean_duration:.5}',
- f'{std_duration:.5}',
- f'{num_calls}',
- f'{total_duration:.5}',
- f'{duration_per:.5}',
+ action,
+ f'{mean_duration:.5}',
+ f'{std_duration:.5}',
+ f'{num_calls}',
+ f'{total_duration:.5}',
+ f'{duration_per:.5}',
)
output_string += sep_lines
output_string += sep
@@ -125,7 +141,6 @@ def log_row(action, mean, std, num_calls, total, per):
class PassThroughProfiler(Profiler):
-
def start(self, action_name: str) -> None:
pass
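A small usage sketch of the profiler shown above (behavior is unchanged by the reformat); the action name is arbitrary:

```python
import time

from algoperf.profiler import Profiler

profiler = Profiler()
with profiler.profile('train_step'):  # times the enclosed block
  time.sleep(0.01)

# Durations are accumulated per action name, in seconds.
print(profiler.recorded_durations['train_step'])
```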
diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py
index 4a674985d..af09e67fc 100644
--- a/algoperf/pytorch_utils.py
+++ b/algoperf/pytorch_utils.py
@@ -1,18 +1,22 @@
import os
from typing import Tuple
-from absl import logging
import jax
import tensorflow as tf
import torch
import torch.distributed as dist
+import torch.nn.functional as F
+from absl import logging
+from torch import Tensor, nn
from algoperf import spec
from algoperf.profiler import Profiler
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
- BatchNorm as ConformerBatchNorm
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \
- BatchNorm as DeepspeechBatchNorm
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import (
+ BatchNorm as ConformerBatchNorm,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import (
+ BatchNorm as DeepspeechBatchNorm,
+)
def pytorch_setup() -> Tuple[bool, int, torch.device, int]:
@@ -58,12 +62,13 @@ def sync_ddp_time(time: float, device: torch.device) -> float:
return time_tensor.item()
-def update_batch_norm_fn(module: spec.ParameterContainer,
- update_batch_norm: bool) -> None:
+def update_batch_norm_fn(
+ module: spec.ParameterContainer, update_batch_norm: bool
+) -> None:
bn_layers = (
- torch.nn.modules.batchnorm._BatchNorm, # PyTorch BN base class.
- ConformerBatchNorm, # Custom BN class for conformer model.
- DeepspeechBatchNorm, # Custom BN class for deepspeech model.
+ torch.nn.modules.batchnorm._BatchNorm, # PyTorch BN base class.
+ ConformerBatchNorm, # Custom BN class for conformer model.
+ DeepspeechBatchNorm, # Custom BN class for deepspeech model.
)
if isinstance(module, bn_layers):
if not update_batch_norm:
@@ -77,3 +82,41 @@ def update_batch_norm_fn(module: spec.ParameterContainer,
module.momentum = 0.0
elif hasattr(module, 'momentum_backup'):
module.momentum = module.momentum_backup
+
+
+class CustomDropout(nn.Module):
+ """A module around torch.nn.functional.dropout."""
+
+ def __init__(self):
+ super().__init__()
+ self._supports_custom_dropout = True
+
+ def forward(self, x: Tensor, p: float) -> Tensor:
+ return F.dropout(x, p, training=self.training)
+
+
+class CustomDropout2d(nn.Module):
+ """A module around torch.nn.functional.dropout2d."""
+
+ def __init__(self):
+ super().__init__()
+ self._supports_custom_dropout = True
+
+ def forward(self, x: Tensor, p: float) -> Tensor:
+ return F.dropout2d(x, p, training=self.training)
+
+
+class SequentialWithDropout(nn.Sequential):
+ """Sequential of modules with dropout."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._supports_custom_dropout = True
+
+ def forward(self, x: Tensor, p: float) -> Tensor:
+ for module in self:
+ if getattr(module, '_supports_custom_dropout', False):
+ x = module(x, p)
+ else:
+ x = module(x)
+ return x
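
The three wrappers added above let a dropout probability be supplied at call time rather than fixed at module construction. A minimal usage sketch (illustrative only, not part of the patch; the layer sizes are arbitrary):

```python
import torch
from torch import nn

from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout

# Hypothetical block: `p` is forwarded only to modules that set
# `_supports_custom_dropout`; plain modules receive just the input tensor.
block = SequentialWithDropout(
    nn.Linear(16, 16),
    nn.ReLU(),
    CustomDropout(),
    nn.Linear(16, 4),
)

x = torch.randn(8, 16)
train_out = block(x, 0.1)  # dropout active while the block is in training mode
block.eval()
eval_out = block(x, 0.1)   # dropout is a no-op in eval mode

```

The `_supports_custom_dropout` flag is what `SequentialWithDropout.forward` checks before threading `p` through to a child module.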
diff --git a/algoperf/random_utils.py b/algoperf/random_utils.py
index a579976ad..1dc773e80 100644
--- a/algoperf/random_utils.py
+++ b/algoperf/random_utils.py
@@ -2,16 +2,16 @@
from typing import Any, List, Union
-from absl import flags
-from absl import logging
import numpy as np
+from absl import flags, logging
try:
import jax.random as jax_rng
except (ImportError, ModuleNotFoundError):
logging.warning(
- 'Could not import jax.random for the submission runner, falling back to '
- 'numpy random_utils.')
+ 'Could not import jax.random for the submission runner, falling back to '
+ 'numpy random_utils.'
+ )
jax_rng = None
FLAGS = flags.FLAGS
@@ -54,8 +54,9 @@ def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
def _check_jax_install() -> None:
if jax_rng is None:
raise ValueError(
- 'Must install jax to use the jax RNG library, or use PyTorch and pass '
- '--framework=pytorch to use the Numpy version instead.')
+ 'Must install jax to use the jax RNG library, or use PyTorch and pass '
+ '--framework=pytorch to use the Numpy version instead.'
+ )
def fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
diff --git a/algoperf/spec.py b/algoperf/spec.py
index cf4f1a14e..5f7b930af 100644
--- a/algoperf/spec.py
+++ b/algoperf/spec.py
@@ -5,10 +5,10 @@
import functools
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
-from absl import logging
import jax
-from torch import nn
import torch.nn.functional as F
+from absl import logging
+from torch import nn
class LossType(enum.Enum):
@@ -53,7 +53,6 @@ class ParameterType(enum.Enum):
# Define this so that, when using pytree iteration utilities, one can iterate
# over the model shapes pytree without iterating over the shape tuples.
class ShapeTuple:
-
def __init__(self, shape_tuple):
self.shape_tuple = shape_tuple
@@ -64,19 +63,22 @@ def __eq__(self, other):
return self.shape_tuple == other.shape_tuple
-Shape = Union[Tuple[int],
- Tuple[int, int],
- Tuple[int, int, int],
- Tuple[int, int, int, int],
- ShapeTuple]
+Shape = Union[
+ Tuple[int],
+ Tuple[int, int],
+ Tuple[int, int, int],
+ Tuple[int, int, int, int],
+ ShapeTuple,
+]
ParameterShapeTree = Dict[str, Dict[str, Shape]]
# If necessary, these can be zipped together easily given they have the same
# structure, to get an iterator over pairs of leaves.
ParameterKey = str
# Dicts can be arbitrarily nested.
-ParameterContainer = Union[Dict[ParameterKey, Dict[ParameterKey, Tensor]],
- nn.Module]
+ParameterContainer = Union[
+ Dict[ParameterKey, Dict[ParameterKey, Tensor]], nn.Module
+]
ParameterTypeTree = Dict[ParameterKey, Dict[ParameterKey, ParameterType]]
RandomState = Any # Union[jax.random.PRNGKey, int, bytes, ...]
@@ -92,7 +94,6 @@ def __eq__(self, other):
class Workload(metaclass=abc.ABCMeta):
-
def __init__(self, *args, **kwargs) -> None:
del args
del kwargs
@@ -107,8 +108,9 @@ def target_metric_name(self) -> str:
"""The name of the target metric (useful for scoring/processing code)."""
@abc.abstractmethod
- def has_reached_validation_target(self, eval_result: Dict[str,
- float]) -> bool:
+ def has_reached_validation_target(
+ self, eval_result: Dict[str, float]
+ ) -> bool:
"""Return whether or not the workload validation goal has been reached."""
@abc.abstractmethod
@@ -117,14 +119,15 @@ def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool:
@abc.abstractmethod
def _build_input_queue(
- self,
- data_rng: RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, Any]]:
+ self,
+ data_rng: RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, Any]]:
"""Build the input queue for the workload data.
This is the only function that is NOT allowed to be called by submitters.
@@ -213,8 +216,9 @@ def param_shapes(self):
"""The shapes of the parameters in the workload model."""
if self._param_shapes is None:
raise ValueError(
- 'This should not happen, workload.init_model_fn() should be called '
- 'before workload.param_shapes!')
+ 'This should not happen, workload.init_model_fn() should be called '
+ 'before workload.param_shapes!'
+ )
return self._param_shapes
@property
@@ -222,8 +226,9 @@ def model_params_types(self):
"""The types of the parameters in the workload model."""
if self._param_types is None:
raise ValueError(
- 'This should not happen, workload.init_model_fn() should be called '
- 'before workload.param_types!')
+ 'This should not happen, workload.init_model_fn() should be called '
+ 'before workload.param_types!'
+ )
return self._param_types
@abc.abstractmethod
@@ -234,10 +239,12 @@ def is_output_params(self, param_key: ParameterKey) -> bool:
# Tuple[RandomState, Optional[float], Optional[float]],
# ParameterContainer]
@abc.abstractmethod
- def init_model_fn(self,
- rng: RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> ModelInitState:
+ def init_model_fn(
+ self,
+ rng: RandomState,
+ dropout_rate: Optional[float] = None,
+ aux_dropout_rate: Optional[float] = None,
+ ) -> ModelInitState:
"""Return (initial_params, initial_model_state)."""
# ModelFn = Callable[
@@ -247,32 +254,39 @@ def init_model_fn(self,
# ModelAuxiliaryState,
# ForwardPassMode,
# RandomState,
- # bool],
+ # bool,
+ # float],
# Tensor]
@abc.abstractmethod
- def model_fn(self,
- params: ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, Tensor],
- model_state: ModelAuxiliaryState,
- mode: ForwardPassMode,
- rng: RandomState,
- update_batch_norm: bool) -> Tuple[Tensor, ModelAuxiliaryState]:
+ def model_fn(
+ self,
+ params: ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, Tensor],
+ model_state: ModelAuxiliaryState,
+ mode: ForwardPassMode,
+ rng: RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float,
+ ) -> Tuple[Tensor, ModelAuxiliaryState]:
"""Return logits_batch"""
# Possible side effect of updating BN.
- def output_activation_fn(self, logits_batch: Tensor,
- framework: str) -> Tensor:
+ def output_activation_fn(
+ self, logits_batch: Tensor, framework: str
+ ) -> Tensor:
"""Turn logits into probabilities, according to the loss_type property."""
if framework not in ['pytorch', 'jax']:
raise ValueError(
- f'`framework` has to be either `pytorch` or `jax`, got {framework}.')
+ f'`framework` has to be either `pytorch` or `jax`, got {framework}.'
+ )
activation_fn = {
- LossType.MEAN_SQUARED_ERROR: lambda z: z,
- LossType.MEAN_ABSOLUTE_ERROR: lambda z: z,
+ LossType.MEAN_SQUARED_ERROR: lambda z: z,
+ LossType.MEAN_ABSOLUTE_ERROR: lambda z: z,
}
is_pytorch = framework == 'pytorch' # If False, framework == 'jax'.
softmax_fn = (
- functools.partial(F.softmax, dim=-1) if is_pytorch else jax.nn.softmax)
+ functools.partial(F.softmax, dim=-1) if is_pytorch else jax.nn.softmax
+ )
sigmoid_fn = F.sigmoid if is_pytorch else jax.nn.sigmoid
activation_fn[LossType.SOFTMAX_CROSS_ENTROPY] = softmax_fn
activation_fn[LossType.SIGMOID_CROSS_ENTROPY] = sigmoid_fn
@@ -284,12 +298,13 @@ def output_activation_fn(self, logits_batch: Tensor,
# `update_params`.
@abc.abstractmethod
def loss_fn(
- self,
- # Dense or one-hot labels, or a tuple of (tensor, padding) for speech.
- label_batch: Union[Tuple[Tensor, Tensor], Tensor],
- logits_batch: Union[Tuple[Tensor, Tensor], Tensor],
- mask_batch: Optional[Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, Tensor]: # differentiable
+ self,
+ # Dense or one-hot labels, or a tuple of (tensor, padding) for speech.
+ label_batch: Union[Tuple[Tensor, Tensor], Tensor],
+ logits_batch: Union[Tuple[Tensor, Tensor], Tensor],
+ mask_batch: Optional[Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -298,48 +313,54 @@ def loss_fn(
"""
@abc.abstractmethod
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: ParameterContainer,
- model_state: ModelAuxiliaryState,
- rng: RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: ParameterContainer,
+ model_state: ModelAuxiliaryState,
+ rng: RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Evaluate the model on a given dataset split, return final scalars."""
- def eval_model(self,
- global_batch_size: int,
- params: ParameterContainer,
- model_state: ModelAuxiliaryState,
- rng: RandomState,
- data_dir: str,
- imagenet_v2_data_dir: Optional[str],
- global_step: int) -> Dict[str, float]:
+ def eval_model(
+ self,
+ global_batch_size: int,
+ params: ParameterContainer,
+ model_state: ModelAuxiliaryState,
+ rng: RandomState,
+ data_dir: str,
+ imagenet_v2_data_dir: Optional[str],
+ global_step: int,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
logging.info('Evaluating on the training split.')
train_metrics = self._eval_model_on_split(
- split='eval_train',
- num_examples=self.num_eval_train_examples,
- global_batch_size=global_batch_size,
- params=params,
- model_state=model_state,
- rng=rng,
- data_dir=data_dir,
- global_step=global_step)
+ split='eval_train',
+ num_examples=self.num_eval_train_examples,
+ global_batch_size=global_batch_size,
+ params=params,
+ model_state=model_state,
+ rng=rng,
+ data_dir=data_dir,
+ global_step=global_step,
+ )
eval_metrics = {'train/' + k: v for k, v in train_metrics.items()}
# We always require a validation set.
logging.info('Evaluating on the validation split.')
validation_metrics = self._eval_model_on_split(
- 'validation',
- num_examples=self.num_validation_examples,
- global_batch_size=global_batch_size,
- params=params,
- model_state=model_state,
- rng=rng,
- data_dir=data_dir,
- global_step=global_step)
+ 'validation',
+ num_examples=self.num_validation_examples,
+ global_batch_size=global_batch_size,
+ params=params,
+ model_state=model_state,
+ rng=rng,
+ data_dir=data_dir,
+ global_step=global_step,
+ )
for k, v in validation_metrics.items():
eval_metrics['validation/' + k] = v
eval_metrics['validation/num_examples'] = self.num_validation_examples
@@ -348,14 +369,15 @@ def eval_model(self,
if self.num_test_examples is not None:
logging.info('Evaluating on the test split.')
test_metrics = self._eval_model_on_split(
- 'test',
- num_examples=self.num_test_examples,
- global_batch_size=global_batch_size,
- params=params,
- model_state=model_state,
- rng=rng,
- data_dir=imagenet_v2_data_dir if imagenet_v2_data_dir else data_dir,
- global_step=global_step)
+ 'test',
+ num_examples=self.num_test_examples,
+ global_batch_size=global_batch_size,
+ params=params,
+ model_state=model_state,
+ rng=rng,
+ data_dir=imagenet_v2_data_dir if imagenet_v2_data_dir else data_dir,
+ global_step=global_step,
+ )
for k, v in test_metrics.items():
eval_metrics['test/' + k] = v
eval_metrics['test/num_examples'] = self.num_test_examples
@@ -372,27 +394,32 @@ class TrainingCompleteError(Exception):
# Training algorithm track submission functions, to be filled in by the
# submitter.
-InitOptimizerFn = Callable[[
+InitOptimizerFn = Callable[
+ [
Workload,
ParameterContainer,
ModelAuxiliaryState,
Hyperparameters,
- RandomState
-],
- OptimizerState]
-
-
-def init_optimizer_state(workload: Workload,
- model_params: ParameterContainer,
- model_state: ModelAuxiliaryState,
- hyperparameters: Hyperparameters,
- rng: RandomState) -> OptimizerState:
+ RandomState,
+ ],
+ OptimizerState,
+]
+
+
+def init_optimizer_state(
+ workload: Workload,
+ model_params: ParameterContainer,
+ model_state: ModelAuxiliaryState,
+ hyperparameters: Hyperparameters,
+ rng: RandomState,
+) -> OptimizerState:
# return initial_optimizer_state
pass
UpdateReturn = Tuple[OptimizerState, ParameterContainer, ModelAuxiliaryState]
-UpdateParamsFn = Callable[[
+UpdateParamsFn = Callable[
+ [
Workload,
ParameterContainer,
ParameterTypeTree,
@@ -404,9 +431,10 @@ def init_optimizer_state(workload: Workload,
List[Tuple[int, float]],
int,
RandomState,
- Optional[Dict[str, Any]]
-],
- UpdateReturn]
+ Optional[Dict[str, Any]],
+ ],
+ UpdateReturn,
+]
# Each call to this function is considered a "step".
@@ -415,23 +443,26 @@ def init_optimizer_state(workload: Workload,
# and if has not actually achieved the goal then it will be considered as not
# achieved the goal and get an infinite time score. Most submissions will likely
# wait until the next free eval and not use this functionality.
-def update_params(workload: Workload,
- current_param_container: ParameterContainer,
- current_params_types: ParameterTypeTree,
- model_state: ModelAuxiliaryState,
- hyperparameters: Hyperparameters,
- batch: Dict[str, Tensor],
- loss_type: LossType,
- optimizer_state: OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> UpdateReturn:
+def update_params(
+ workload: Workload,
+ current_param_container: ParameterContainer,
+ current_params_types: ParameterTypeTree,
+ model_state: ModelAuxiliaryState,
+ hyperparameters: Hyperparameters,
+ batch: Dict[str, Tensor],
+ loss_type: LossType,
+ optimizer_state: OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
pass
-PrepareForEvalFn = Callable[[
+PrepareForEvalFn = Callable[
+ [
Workload,
ParameterContainer,
ParameterTypeTree,
@@ -441,27 +472,31 @@ def update_params(workload: Workload,
OptimizerState,
List[Tuple[int, float]],
int,
- RandomState
-],
- UpdateReturn]
+ RandomState,
+ ],
+ UpdateReturn,
+]
# Prepare model and optimizer for evaluation.
-def prepare_for_eval(workload: Workload,
- current_param_container: ParameterContainer,
- current_params_types: ParameterTypeTree,
- model_state: ModelAuxiliaryState,
- hyperparameters: Hyperparameters,
- loss_type: LossType,
- optimizer_state: OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: RandomState) -> UpdateReturn:
+def prepare_for_eval(
+ workload: Workload,
+ current_param_container: ParameterContainer,
+ current_params_types: ParameterTypeTree,
+ model_state: ModelAuxiliaryState,
+ hyperparameters: Hyperparameters,
+ loss_type: LossType,
+ optimizer_state: OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: RandomState,
+) -> UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
pass
-DataSelectionFn = Callable[[
+DataSelectionFn = Callable[
+ [
Workload,
Iterator[Dict[str, Any]],
OptimizerState,
@@ -469,21 +504,24 @@ def prepare_for_eval(workload: Workload,
LossType,
Hyperparameters,
int,
- RandomState
-],
- Tuple[Tensor, Tensor]]
+ RandomState,
+ ],
+ Tuple[Tensor, Tensor],
+]
# Not allowed to update the model parameters, hyperparameters, global step, or
# optimizer state.
-def data_selection(workload: Workload,
- input_queue: Iterator[Dict[str, Any]],
- optimizer_state: OptimizerState,
- current_param_container: ParameterContainer,
- model_state: ModelAuxiliaryState,
- hyperparameters: Hyperparameters,
- global_step: int,
- rng: RandomState) -> Dict[str, Tensor]:
+def data_selection(
+ workload: Workload,
+ input_queue: Iterator[Dict[str, Any]],
+ optimizer_state: OptimizerState,
+ current_param_container: ParameterContainer,
+ model_state: ModelAuxiliaryState,
+ hyperparameters: Hyperparameters,
+ global_step: int,
+ rng: RandomState,
+) -> Dict[str, Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
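
The most consequential change in `spec.py` is that the abstract `model_fn` now takes `dropout_rate` as an explicit per-call argument (see the updated `# ModelFn` comment and signature above). A sketch of what a submission-side forward pass might look like under the new interface; `workload`, `params`, `batch`, `model_state`, and `rng` stand in for whatever the runner hands to the submission functions:

```python
from algoperf import spec


def train_forward(workload, params, batch, model_state, rng, dropout_rate=0.1):
  # dropout_rate is now threaded through model_fn instead of being fixed at
  # model-initialization time.
  logits, new_model_state = workload.model_fn(
      params,
      batch,
      model_state,
      spec.ForwardPassMode.TRAIN,
      rng,
      update_batch_norm=True,
      dropout_rate=dropout_rate,
  )
  return logits, new_model_state
```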
diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py
index 728d05f29..7fbc95bc6 100644
--- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py
+++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py
@@ -8,22 +8,24 @@
import functools
from typing import Dict, Iterator, Tuple
-from flax import jax_utils
import jax
import tensorflow as tf
import tensorflow_datasets as tfds
+from flax import jax_utils
from algoperf import spec
from algoperf.data_utils import shard_and_maybe_pad_np
-def preprocess_for_train(image: spec.Tensor,
- rng: spec.RandomState,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float],
- crop_size: int,
- padding_size: int,
- dtype: tf.DType = tf.float32) -> spec.Tensor:
+def preprocess_for_train(
+ image: spec.Tensor,
+ rng: spec.RandomState,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+ crop_size: int,
+ padding_size: int,
+ dtype: tf.DType = tf.float32,
+) -> spec.Tensor:
"""Preprocesses the given image for training.
Args:
@@ -44,20 +46,23 @@ def preprocess_for_train(image: spec.Tensor,
flip_rng = rng[1, :]
image_shape = tf.shape(image)
- image = tf.image.resize_with_crop_or_pad(image,
- image_shape[0] + padding_size,
- image_shape[1] + padding_size)
+ image = tf.image.resize_with_crop_or_pad(
+ image, image_shape[0] + padding_size, image_shape[1] + padding_size
+ )
image = tf.image.stateless_random_crop(
- image, (crop_size, crop_size, 3), seed=crop_rng)
+ image, (crop_size, crop_size, 3), seed=crop_rng
+ )
image = tf.image.stateless_random_flip_left_right(image, seed=flip_rng)
image = normalize_image(image, mean_rgb, stddev_rgb, dtype=dtype)
return image
-def preprocess_for_eval(image: spec.Tensor,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float],
- dtype: tf.DType = tf.float32) -> spec.Tensor:
+def preprocess_for_eval(
+ image: spec.Tensor,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+ dtype: tf.DType = tf.float32,
+) -> spec.Tensor:
"""Preprocesses the given image for evaluation.
Args:
@@ -74,10 +79,12 @@ def preprocess_for_eval(image: spec.Tensor,
return image
-def normalize_image(image: spec.Tensor,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float],
- dtype=tf.float32) -> spec.Tensor:
+def normalize_image(
+ image: spec.Tensor,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+ dtype=tf.float32,
+) -> spec.Tensor:
image = tf.image.convert_image_dtype(image, dtype)
image -= tf.constant(mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
image /= tf.constant(stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
@@ -85,17 +92,17 @@ def normalize_image(image: spec.Tensor,
def create_split(
- split: str,
- dataset_builder: tfds.core.dataset_builder.DatasetBuilder,
- rng: spec.RandomState,
- global_batch_size: int,
- train: bool,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float],
- cache: bool = False,
- repeat_final_dataset: bool = False,
- crop_size: int = 32,
- padding_size: int = 4,
+ split: str,
+ dataset_builder: tfds.core.dataset_builder.DatasetBuilder,
+ rng: spec.RandomState,
+ global_batch_size: int,
+ train: bool,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+ cache: bool = False,
+ repeat_final_dataset: bool = False,
+ crop_size: int = 32,
+ padding_size: int = 4,
) -> Iterator[Dict[str, spec.Tensor]]:
"""Creates a split from the CIFAR-10 dataset using TensorFlow Datasets."""
shuffle_rng, preprocess_rng = jax.random.split(rng, 2)
@@ -104,14 +111,17 @@ def preprocess_example(example_index, example):
dtype = tf.float32
if train:
per_step_preprocess_rng = tf.random.experimental.stateless_fold_in(
- tf.cast(preprocess_rng, tf.int64), example_index)
- image = preprocess_for_train(example['image'],
- per_step_preprocess_rng,
- mean_rgb,
- stddev_rgb,
- crop_size,
- padding_size,
- dtype)
+ tf.cast(preprocess_rng, tf.int64), example_index
+ )
+ image = preprocess_for_train(
+ example['image'],
+ per_step_preprocess_rng,
+ mean_rgb,
+ stddev_rgb,
+ crop_size,
+ padding_size,
+ dtype,
+ )
else:
image = preprocess_for_eval(example['image'], mean_rgb, stddev_rgb, dtype)
return {'inputs': image, 'targets': example['label']}
@@ -132,7 +142,8 @@ def preprocess_example(example_index, example):
# index that we can fold into the RNG seed.
ds = ds.enumerate()
ds = ds.map(
- preprocess_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ preprocess_example, num_parallel_calls=tf.data.experimental.AUTOTUNE
+ )
ds = ds.batch(global_batch_size, drop_remainder=train)
if repeat_final_dataset:
@@ -144,32 +155,36 @@ def preprocess_example(example_index, example):
def create_input_iter(
- split: str,
- dataset_builder: tfds.core.dataset_builder.DatasetBuilder,
- rng: spec.RandomState,
- global_batch_size: int,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float],
- crop_size: int,
- padding_size: int,
- train: bool,
- cache: bool,
- repeat_final_dataset: bool) -> Iterator[Dict[str, spec.Tensor]]:
+ split: str,
+ dataset_builder: tfds.core.dataset_builder.DatasetBuilder,
+ rng: spec.RandomState,
+ global_batch_size: int,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+ crop_size: int,
+ padding_size: int,
+ train: bool,
+ cache: bool,
+ repeat_final_dataset: bool,
+) -> Iterator[Dict[str, spec.Tensor]]:
ds = create_split(
- split,
- dataset_builder,
- rng,
- global_batch_size,
- train=train,
- mean_rgb=mean_rgb,
- stddev_rgb=stddev_rgb,
- cache=cache,
- repeat_final_dataset=repeat_final_dataset,
- crop_size=crop_size,
- padding_size=padding_size)
+ split,
+ dataset_builder,
+ rng,
+ global_batch_size,
+ train=train,
+ mean_rgb=mean_rgb,
+ stddev_rgb=stddev_rgb,
+ cache=cache,
+ repeat_final_dataset=repeat_final_dataset,
+ crop_size=crop_size,
+ padding_size=padding_size,
+ )
it = map(
- functools.partial(
- shard_and_maybe_pad_np, global_batch_size=global_batch_size),
- ds)
+ functools.partial(
+ shard_and_maybe_pad_np, global_batch_size=global_batch_size
+ ),
+ ds,
+ )
it = jax_utils.prefetch_to_device(it, 2)
return it
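
The per-example preprocessing above derives a deterministic seed for each example by folding the example index into a base seed before the stateless crop/flip ops. A standalone sketch of that folding pattern (the seed and index values are arbitrary, not taken from the pipeline):

```python
import tensorflow as tf

base_seed = tf.constant([0, 42], dtype=tf.int64)  # stateless seed of shape [2]
example_index = tf.constant(7, dtype=tf.int64)

# Folding makes the derived seed a deterministic function of the example index.
per_example_seed = tf.random.experimental.stateless_fold_in(
    base_seed, example_index
)

padded_image = tf.zeros([36, 36, 3], tf.float32)  # 32x32 image + 4px padding
crop = tf.image.stateless_random_crop(
    padded_image, size=(32, 32, 3), seed=per_example_seed
)
```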
diff --git a/algoperf/workloads/cifar/cifar_jax/models.py b/algoperf/workloads/cifar/cifar_jax/models.py
index 957079272..95238c997 100644
--- a/algoperf/workloads/cifar/cifar_jax/models.py
+++ b/algoperf/workloads/cifar/cifar_jax/models.py
@@ -7,8 +7,8 @@
import functools
from typing import Any, Callable, Tuple
-from flax import linen as nn
import jax.numpy as jnp
+from flax import linen as nn
from algoperf import spec
from algoperf.workloads.imagenet_resnet.imagenet_jax.models import ResNetBlock
@@ -25,48 +25,52 @@ class ResNet(nn.Module):
act: Callable = nn.relu
@nn.compact
- def __call__(self,
- x: spec.Tensor,
- update_batch_norm: bool = True,
- use_running_average_bn: bool = None) -> spec.Tensor:
+ def __call__(
+ self,
+ x: spec.Tensor,
+ update_batch_norm: bool = True,
+ use_running_average_bn: bool = None,
+ ) -> spec.Tensor:
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
# Preserve default behavior for backwards compatibility
if use_running_average_bn is None:
use_running_average_bn = not update_batch_norm
norm = functools.partial(
- nn.BatchNorm,
- use_running_average=use_running_average_bn,
- momentum=0.9,
- epsilon=1e-5,
- dtype=self.dtype)
+ nn.BatchNorm,
+ use_running_average=use_running_average_bn,
+ momentum=0.9,
+ epsilon=1e-5,
+ dtype=self.dtype,
+ )
x = conv(
- self.num_filters, (3, 3), (1, 1),
- padding=[(1, 1), (1, 1)],
- name='Conv_init')(
- x)
+ self.num_filters,
+ (3, 3),
+ (1, 1),
+ padding=[(1, 1), (1, 1)],
+ name='Conv_init',
+ )(x)
x = norm(name='BatchNorm_init')(x)
x = nn.relu(x)
for i, block_size in enumerate(self.stage_sizes):
for j in range(block_size):
strides = (2, 2) if i > 0 and j == 0 else (1, 1)
x = self.block_cls(
- self.num_filters * 2**i,
- strides=strides,
- conv=conv,
- norm=norm,
- act=self.act)(
- x)
+ self.num_filters * 2**i,
+ strides=strides,
+ conv=conv,
+ norm=norm,
+ act=self.act,
+ )(x)
x = nn.avg_pool(x, (4, 4), strides=(4, 4))
x = jnp.mean(x, axis=(1, 2))
x = nn.Dense(
- self.num_classes,
- kernel_init=nn.initializers.normal(),
- dtype=self.dtype)(
- x)
+ self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype
+ )(x)
return x
ResNet18 = functools.partial(
- ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock)
+ ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock
+)
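
For reference, a minimal sketch of instantiating this Flax ResNet-18 and initializing its variables, mirroring what the Jax CIFAR workload's `init_model_fn` does in the diff that follows:

```python
import jax
import jax.numpy as jnp

from algoperf.workloads.cifar.cifar_jax.models import ResNet18

model = ResNet18(num_classes=10, dtype=jnp.float32)
variables = jax.jit(model.init)(
    {'params': jax.random.PRNGKey(0)}, jnp.ones((1, 32, 32, 3), jnp.float32)
)
# `params` holds the trainable weights, `batch_stats` the BatchNorm statistics.
params = variables['params']
batch_stats = variables['batch_stats']
```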
diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py
index ad43bc62f..bc26e3899 100644
--- a/algoperf/workloads/cifar/cifar_jax/workload.py
+++ b/algoperf/workloads/cifar/cifar_jax/workload.py
@@ -3,32 +3,30 @@
import functools
from typing import Any, Dict, Iterator, Optional, Tuple
-from flax import jax_utils
-from flax import linen as nn
-from flax.core import pop
import jax
-from jax import lax
import jax.numpy as jnp
import optax
import tensorflow_datasets as tfds
+from flax import jax_utils
+from flax import linen as nn
+from flax.core import pop
+from jax import lax
-from algoperf import param_utils
-from algoperf import spec
+from algoperf import param_utils, spec
from algoperf.workloads.cifar.cifar_jax import models
from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter
from algoperf.workloads.cifar.workload import BaseCifarWorkload
class CifarWorkload(BaseCifarWorkload):
-
def _build_cifar_dataset(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
) -> Iterator[Dict[str, spec.Tensor]]:
ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir)
train = split == 'train'
@@ -38,38 +36,38 @@ def _build_cifar_dataset(
elif split == 'validation':
split = f'train[{self.num_train_examples}:]'
ds = create_input_iter(
- split,
- ds_builder,
- data_rng,
- batch_size,
- self.train_mean,
- self.train_stddev,
- self.crop_size,
- self.padding_size,
- train=train,
- cache=not train if cache is None else cache,
- repeat_final_dataset=repeat_final_dataset)
+ split,
+ ds_builder,
+ data_rng,
+ batch_size,
+ self.train_mean,
+ self.train_stddev,
+ self.crop_size,
+ self.padding_size,
+ train=train,
+ cache=not train if cache is None else cache,
+ repeat_final_dataset=repeat_final_dataset,
+ )
return ds
def _build_input_queue(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
del num_batches
- return self._build_cifar_dataset(data_rng,
- split,
- data_dir,
- global_batch_size,
- cache,
- repeat_final_dataset)
+ return self._build_cifar_dataset(
+ data_rng, split, data_dir, global_batch_size, cache, repeat_final_dataset
+ )
def sync_batch_stats(
- self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState:
+ self, model_state: spec.ModelAuxiliaryState
+ ) -> spec.ModelAuxiliaryState:
"""Sync the batch statistics across replicas."""
# An axis_name is passed to pmap which can then be used by pmean.
# In this case each device has its own version of the batch statistics
@@ -79,20 +77,15 @@ def sync_batch_stats(
new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
return new_model_state
- def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
"""Dropout is unused."""
- del dropout_rate
- del aux_dropout_rate
model_cls = getattr(models, 'ResNet18')
model = model_cls(num_classes=self._num_classes, dtype=jnp.float32)
self._model = model
input_shape = (1, 32, 32, 3)
- variables = jax.jit(model.init)({'params': rng},
- jnp.ones(input_shape, model.dtype))
+ variables = jax.jit(model.init)(
+ {'params': rng}, jnp.ones(input_shape, model.dtype)
+ )
model_state, params = pop(variables, 'params')
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
@@ -104,43 +97,46 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_0'
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- use_running_average_bn: Optional[bool] = None
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ use_running_average_bn: Optional[bool] = None,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
if update_batch_norm:
logits, new_model_state = self._model.apply(
- variables,
- augmented_and_preprocessed_input_batch['inputs'],
- update_batch_norm=update_batch_norm,
- mutable=['batch_stats'],
- use_running_average_bn=use_running_average_bn)
+ variables,
+ augmented_and_preprocessed_input_batch['inputs'],
+ update_batch_norm=update_batch_norm,
+ mutable=['batch_stats'],
+ use_running_average_bn=use_running_average_bn,
+ )
return logits, new_model_state
else:
logits = self._model.apply(
- variables,
- augmented_and_preprocessed_input_batch['inputs'],
- update_batch_norm=update_batch_norm,
- mutable=False,
- use_running_average_bn=use_running_average_bn)
+ variables,
+ augmented_and_preprocessed_input_batch['inputs'],
+ update_batch_norm=update_batch_norm,
+ mutable=False,
+ use_running_average_bn=use_running_average_bn,
+ )
return logits, model_state
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -150,7 +146,8 @@ def loss_fn(
one_hot_targets = jax.nn.one_hot(label_batch, self._num_classes)
smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing)
per_example_losses = -jnp.sum(
- smoothed_targets * nn.log_softmax(logits_batch), axis=-1)
+ smoothed_targets * nn.log_softmax(logits_batch), axis=-1
+ )
# `mask_batch` is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -159,51 +156,53 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': n_valid_examples,
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': n_valid_examples,
+ 'per_example': per_example_losses,
}
- def _compute_metrics(self,
- logits: spec.Tensor,
- labels: spec.Tensor,
- weights: spec.Tensor) -> Dict[str, spec.Tensor]:
+ def _compute_metrics(
+ self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor
+ ) -> Dict[str, spec.Tensor]:
summed_loss = self.loss_fn(labels, logits, weights)['summed']
# Number of correct predictions.
accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights)
metrics = {
- 'loss': summed_loss,
- 'accuracy': accuracy,
+ 'loss': summed_loss,
+ 'accuracy': accuracy,
}
metrics = lax.psum(metrics, axis_name='batch')
return metrics
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, 0, 0, 0, None),
- static_broadcasted_argnums=(0,))
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, 0, 0, 0, None),
+ static_broadcasted_argnums=(0,),
+ )
def _eval_model(
- self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
"""Return the mean accuracy and loss as a dict."""
logits, _ = self.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.EVAL,
- rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.EVAL,
+ rng,
+ update_batch_norm=False,
+ )
weights = batch.get('weights')
if weights is None:
weights = jnp.ones(len(logits))
return self._compute_metrics(logits, batch['targets'], weights)
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics)
diff --git a/algoperf/workloads/cifar/cifar_pytorch/models.py b/algoperf/workloads/cifar/cifar_pytorch/models.py
index e6a7a8a81..0e08f5c5a 100644
--- a/algoperf/workloads/cifar/cifar_pytorch/models.py
+++ b/algoperf/workloads/cifar/cifar_pytorch/models.py
@@ -12,23 +12,24 @@
from algoperf import spec
from algoperf.init_utils import pytorch_default_init
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \
- BasicBlock
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \
- Bottleneck
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import conv1x1
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import (
+ BasicBlock,
+ Bottleneck,
+ conv1x1,
+)
class ResNet(nn.Module):
-
- def __init__(self,
- block: Type[Union[BasicBlock, Bottleneck]],
- layers: List[int],
- num_classes: int = 10,
- groups: int = 1,
- width_per_group: int = 64,
- replace_stride_with_dilation: Optional[List[bool]] = None,
- norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
+ def __init__(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ num_classes: int = 10,
+ groups: int = 1,
+ width_per_group: int = 64,
+ replace_stride_with_dilation: Optional[List[bool]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ ) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
@@ -42,21 +43,26 @@ def __init__(self,
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError(
- 'replace_stride_with_dilation should be None '
- f'or a 3-element tuple, got {replace_stride_with_dilation}')
+ 'replace_stride_with_dilation should be None '
+ f'or a 3-element tuple, got {replace_stride_with_dilation}'
+ )
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(
- 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
+ )
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(
- block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
+ block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
+ )
self.layer3 = self._make_layer(
- block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
+ block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
+ )
self.layer4 = self._make_layer(
- block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
+ block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
+ )
self.fc = nn.Linear(512 * block.expansion, num_classes)
self.reset_parameters()
@@ -68,7 +74,7 @@ def reset_parameters(self) -> None:
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
nn.init.normal_(self.fc.weight, std=1e-2)
- nn.init.constant_(self.fc.bias, 0.)
+ nn.init.constant_(self.fc.bias, 0.0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros,
@@ -81,12 +87,14 @@ def reset_parameters(self) -> None:
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
- def _make_layer(self,
- block: Type[Union[BasicBlock, Bottleneck]],
- planes: int,
- blocks: int,
- stride: int = 1,
- dilate: bool = False) -> nn.Sequential:
+ def _make_layer(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ planes: int,
+ blocks: int,
+ stride: int = 1,
+ dilate: bool = False,
+ ) -> nn.Sequential:
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
@@ -95,32 +103,39 @@ def _make_layer(self,
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = torch.nn.Sequential(
- collections.OrderedDict([
- ("conv", conv1x1(self.inplanes, planes * block.expansion,
- stride)),
- ("bn", norm_layer(planes * block.expansion)),
- ]))
+ collections.OrderedDict(
+ [
+ ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)),
+ ('bn', norm_layer(planes * block.expansion)),
+ ]
+ )
+ )
layers = []
layers.append(
- block(self.inplanes,
- planes,
- stride,
- downsample,
- self.groups,
- self.base_width,
- previous_dilation,
- norm_layer))
+ block(
+ self.inplanes,
+ planes,
+ stride,
+ downsample,
+ self.groups,
+ self.base_width,
+ previous_dilation,
+ norm_layer,
+ )
+ )
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
- block(
- self.inplanes,
- planes,
- groups=self.groups,
- base_width=self.base_width,
- dilation=self.dilation,
- norm_layer=norm_layer))
+ block(
+ self.inplanes,
+ planes,
+ groups=self.groups,
+ base_width=self.base_width,
+ dilation=self.dilation,
+ norm_layer=norm_layer,
+ )
+ )
return nn.Sequential(*layers)
diff --git a/algoperf/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py
index d05131c27..f1189bebc 100644
--- a/algoperf/workloads/cifar/cifar_pytorch/workload.py
+++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py
@@ -12,10 +12,7 @@
from torchvision import transforms
from torchvision.datasets import CIFAR10
-from algoperf import data_utils
-from algoperf import param_utils
-from algoperf import pytorch_utils
-from algoperf import spec
+from algoperf import data_utils, param_utils, pytorch_utils, spec
from algoperf.workloads.cifar.cifar_pytorch.models import resnet18
from algoperf.workloads.cifar.workload import BaseCifarWorkload
@@ -23,7 +20,6 @@
class CifarWorkload(BaseCifarWorkload):
-
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# Is set in submission_runner.py for workloads with PyTorch evaluation
@@ -34,7 +30,8 @@ def __init__(self, *args, **kwargs) -> None:
def eval_num_workers(self) -> int:
if self._eval_num_workers is None:
raise ValueError(
- 'eval_num_workers property must be set before workload is used.')
+ 'eval_num_workers property must be set before workload is used.'
+ )
return self._eval_num_workers
@eval_num_workers.setter
@@ -42,47 +39,54 @@ def eval_num_workers(self, eval_num_workers: int):
self._eval_num_workers = eval_num_workers
def _build_dataset(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
) -> torch.utils.data.DataLoader:
del cache
del repeat_final_dataset
is_train = split == 'train'
- normalize = transforms.Compose([
+ normalize = transforms.Compose(
+ [
transforms.ToTensor(),
transforms.Normalize(mean=self.train_mean, std=self.train_stddev),
- ])
+ ]
+ )
eval_transform_config = normalize
- train_transform_config = transforms.Compose([
+ train_transform_config = transforms.Compose(
+ [
transforms.RandomCrop(
- size=self.crop_size,
- padding=self.padding_size,
+ size=self.crop_size,
+ padding=self.padding_size,
),
transforms.RandomHorizontalFlip(),
normalize,
- ])
+ ]
+ )
transform = train_transform_config if is_train else eval_transform_config
dataset = CIFAR10(
- root=data_dir,
- train=split in ['train', 'eval_train', 'validation'],
- download=False,
- transform=transform)
+ root=data_dir,
+ train=split in ['train', 'eval_train', 'validation'],
+ download=False,
+ transform=transform,
+ )
assert self.num_train_examples + self.num_validation_examples == 50000
indices = list(range(50000))
indices_split = {
- 'train': indices[:self.num_train_examples],
- 'validation': indices[self.num_train_examples:],
+ 'train': indices[: self.num_train_examples],
+ 'validation': indices[self.num_train_examples :],
}
if split == 'eval_train':
train_indices = indices_split['train']
random.Random(int(data_rng[0])).shuffle(train_indices)
- indices_split['eval_train'] = train_indices[:self.num_eval_train_examples]
+ indices_split['eval_train'] = train_indices[
+ : self.num_eval_train_examples
+ ]
if split in indices_split:
dataset = torch.utils.data.Subset(dataset, indices_split[split])
@@ -92,30 +96,34 @@ def _build_dataset(
ds_iter_batch_size = per_device_batch_size
if is_train:
sampler = torch.utils.data.distributed.DistributedSampler(
- dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True)
+ dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True
+ )
else:
sampler = data_utils.DistributedEvalSampler(
- dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False)
+ dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False
+ )
else:
ds_iter_batch_size = global_batch_size
dataloader = torch.utils.data.DataLoader(
- dataset,
- batch_size=ds_iter_batch_size,
- shuffle=not USE_PYTORCH_DDP and is_train,
- sampler=sampler,
- num_workers=4 if is_train else self.eval_num_workers,
- pin_memory=True,
- drop_last=is_train)
+ dataset,
+ batch_size=ds_iter_batch_size,
+ shuffle=not USE_PYTORCH_DDP and is_train,
+ sampler=sampler,
+ num_workers=4 if is_train else self.eval_num_workers,
+ pin_memory=True,
+ drop_last=is_train,
+ )
dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE)
dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP)
return dataloader
def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ self,
+ rng: spec.RandomState,
+ dropout_rate: Optional[float] = None,
+ aux_dropout_rate: Optional[float] = None,
+ ) -> spec.ModelInitState:
"""Dropout is unused."""
del dropout_rate
del aux_dropout_rate
@@ -143,30 +151,34 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['fc.weight', 'fc.bias']
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
model = params
if mode == spec.ForwardPassMode.EVAL:
if update_batch_norm:
raise ValueError(
- 'Batch norm statistics cannot be updated during evaluation.')
+ 'Batch norm statistics cannot be updated during evaluation.'
+ )
model.eval()
if mode == spec.ForwardPassMode.TRAIN:
model.train()
model.apply(
- functools.partial(
- pytorch_utils.update_batch_norm_fn,
- update_batch_norm=update_batch_norm))
+ functools.partial(
+ pytorch_utils.update_batch_norm_fn,
+ update_batch_norm=update_batch_norm,
+ )
+ )
contexts = {
- spec.ForwardPassMode.EVAL: torch.no_grad,
- spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+ spec.ForwardPassMode.EVAL: torch.no_grad,
+ spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
logits_batch = model(augmented_and_preprocessed_input_batch['inputs'])
@@ -175,11 +187,12 @@ def model_fn(
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -187,10 +200,11 @@ def loss_fn(
(not synced across devices).
"""
per_example_losses = F.cross_entropy(
- logits_batch,
- label_batch,
- reduction='none',
- label_smoothing=label_smoothing)
+ logits_batch,
+ label_batch,
+ reduction='none',
+ label_smoothing=label_smoothing,
+ )
# `mask_batch` is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -199,25 +213,27 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
+ 'per_example': per_example_losses,
}
def _eval_model(
- self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
"""Return the mean accuracy and loss as a dict."""
logits, _ = self.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.EVAL,
- rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.EVAL,
+ rng,
+ update_batch_norm=False,
+ )
targets = batch['targets']
weights = batch.get('weights')
if weights is None:
@@ -229,8 +245,8 @@ def _eval_model(
return {'accuracy': accuracy, 'loss': summed_loss}
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
dist.all_reduce(metric)
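
One pattern worth noting in the PyTorch workload above is the mode-dependent context manager: evaluation runs under `torch.no_grad`, training under a null context. A self-contained sketch of that dispatch (the model and batch below are hypothetical stand-ins):

```python
import contextlib

import torch
from torch import nn

MODE_CONTEXTS = {
    'eval': torch.no_grad,
    'train': contextlib.nullcontext,
}


def forward(model: nn.Module, x: torch.Tensor, mode: str) -> torch.Tensor:
  # Gradient tracking is disabled for 'eval' and left untouched for 'train'.
  with MODE_CONTEXTS[mode]():
    return model(x)


model = nn.Linear(8, 2)
logits = forward(model, torch.randn(4, 8), mode='eval')
```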
diff --git a/algoperf/workloads/cifar/workload.py b/algoperf/workloads/cifar/workload.py
index c0d565108..31636807c 100644
--- a/algoperf/workloads/cifar/workload.py
+++ b/algoperf/workloads/cifar/workload.py
@@ -7,15 +7,14 @@
import jax
import torch
+import algoperf.random_utils as prng
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
-import algoperf.random_utils as prng
USE_PYTORCH_DDP, _, _, _ = pytorch_setup()
class BaseCifarWorkload(spec.Workload):
-
_num_classes: int = 10
@property
@@ -23,8 +22,9 @@ def target_metric_name(self) -> str:
"""The name of the target metric (useful for scoring/processing code)."""
return 'accuracy'
- def has_reached_validation_target(self, eval_result: Dict[str,
- float]) -> bool:
+ def has_reached_validation_target(
+ self, eval_result: Dict[str, float]
+ ) -> bool:
return eval_result['validation/accuracy'] > self.validation_target_value
@property
@@ -51,8 +51,9 @@ def num_eval_train_examples(self) -> int:
# Round up from num_validation_examples (which is the default for
# num_eval_train_examples) to the next multiple of eval_batch_size, so that
# we don't have to extract the correctly sized subset of the training data.
- rounded_up_multiple = math.ceil(self.num_validation_examples /
- self.eval_batch_size)
+ rounded_up_multiple = math.ceil(
+ self.num_validation_examples / self.eval_batch_size
+ )
return rounded_up_multiple * self.eval_batch_size
@property
@@ -93,37 +94,35 @@ def eval_period_time_sec(self) -> int:
return 600 # 10 mins.
def _build_dataset(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
) -> Iterator[Dict[str, spec.Tensor]]:
raise NotImplementedError
def _build_input_queue(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
del num_batches
if split == 'test':
if not cache:
raise ValueError('cache must be True for split=test.')
if not repeat_final_dataset:
raise ValueError('repeat_final_dataset must be True for split=test.')
- return self._build_dataset(data_rng,
- split,
- data_dir,
- global_batch_size,
- cache,
- repeat_final_dataset)
+ return self._build_dataset(
+ data_rng, split, data_dir, global_batch_size, cache, repeat_final_dataset
+ )
@property
def step_hint(self) -> int:
@@ -133,39 +132,43 @@ def step_hint(self) -> int:
return 4883
def _eval_model(
- self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
raise NotImplementedError
@abc.abstractmethod
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del global_step
data_rng, model_rng = prng.split(rng, 2)
if split not in self._eval_iters:
self._eval_iters[split] = self._build_input_queue(
- data_rng=data_rng,
- split=split,
- data_dir=data_dir,
- global_batch_size=global_batch_size,
- cache=True,
- repeat_final_dataset=True)
+ data_rng=data_rng,
+ split=split,
+ data_dir=data_dir,
+ global_batch_size=global_batch_size,
+ cache=True,
+ repeat_final_dataset=True,
+ )
num_batches = int(math.ceil(num_examples / global_batch_size))
num_devices = max(torch.cuda.device_count(), jax.local_device_count())
@@ -174,10 +177,9 @@ def _eval_model_on_split(self,
batch = next(self._eval_iters[split])
per_device_model_rngs = prng.split(model_rng, num_devices)
# We already average these metrics across devices inside _compute_metrics.
- synced_metrics = self._eval_model(params,
- batch,
- model_state,
- per_device_model_rngs)
+ synced_metrics = self._eval_model(
+ params, batch, model_state, per_device_model_rngs
+ )
for metric_name, metric_value in synced_metrics.items():
if metric_name not in eval_metrics:
eval_metrics[metric_name] = 0.0
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
index 6d9a489ff..706c2b51a 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
@@ -3,8 +3,12 @@
from typing import Sequence
import flax.linen as nn
-from jax import nn as jnn
import jax.numpy as jnp
+from jax import nn as jnn
+
+from algoperf.jax_utils import Dropout
+
+DROPOUT_RATE = 0.0
class DLRMResNet(nn.Module):
@@ -23,12 +27,12 @@ class DLRMResNet(nn.Module):
mlp_bottom_dims: Sequence[int] = (256, 256, 256)
mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1)
embed_dim: int = 128
- dropout_rate: float = 0.0
+ dropout_rate: float = DROPOUT_RATE
use_layer_norm: bool = False # Unused.
embedding_init_multiplier: float = None # Unused
@nn.compact
- def __call__(self, x, train):
+ def __call__(self, x, train, dropout_rate=DROPOUT_RATE):
bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
@@ -36,20 +40,18 @@ def __call__(self, x, train):
mlp_bottom_dims = self.mlp_bottom_dims
bot_mlp_input = nn.Dense(
- mlp_bottom_dims[0],
- kernel_init=jnn.initializers.glorot_uniform(),
- bias_init=jnn.initializers.normal(stddev=1.0 / mlp_bottom_dims[0]**0.5),
- )(
- bot_mlp_input)
+ mlp_bottom_dims[0],
+ kernel_init=jnn.initializers.glorot_uniform(),
+ bias_init=jnn.initializers.normal(stddev=1.0 / mlp_bottom_dims[0] ** 0.5),
+ )(bot_mlp_input)
bot_mlp_input = nn.relu(bot_mlp_input)
for dense_dim in mlp_bottom_dims[1:]:
x = nn.Dense(
- dense_dim,
- kernel_init=jnn.initializers.glorot_uniform(),
- bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5),
- )(
- bot_mlp_input)
+ dense_dim,
+ kernel_init=jnn.initializers.glorot_uniform(),
+ bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5),
+ )(bot_mlp_input)
bot_mlp_input += nn.relu(x)
base_init_fn = jnn.initializers.uniform(scale=1.0)
@@ -59,46 +61,51 @@ def __call__(self, x, train):
def scaled_init(key, shape, dtype=jnp.float_):
return base_init_fn(key, shape, dtype) / jnp.sqrt(self.vocab_size)
- embedding_table = self.param('embedding_table',
- scaled_init, [self.vocab_size, self.embed_dim])
+ embedding_table = self.param(
+ 'embedding_table', scaled_init, [self.vocab_size, self.embed_dim]
+ )
embed_features = embedding_table[idx_lookup]
batch_size = bot_mlp_input.shape[0]
- embed_features = jnp.reshape(embed_features,
- (batch_size, 26 * self.embed_dim))
+ embed_features = jnp.reshape(
+ embed_features, (batch_size, 26 * self.embed_dim)
+ )
top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1)
mlp_input_dim = top_mlp_input.shape[1]
mlp_top_dims = self.mlp_top_dims
num_layers_top = len(mlp_top_dims)
top_mlp_input = nn.Dense(
- mlp_top_dims[0],
- kernel_init=jnn.initializers.normal(
- stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0]))),
- bias_init=jnn.initializers.normal(
- stddev=jnp.sqrt(1.0 / mlp_top_dims[0])))(
- top_mlp_input)
+ mlp_top_dims[0],
+ kernel_init=jnn.initializers.normal(
+ stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0]))
+ ),
+ bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / mlp_top_dims[0])),
+ )(top_mlp_input)
top_mlp_input = nn.relu(top_mlp_input)
for layer_idx, fan_out in list(enumerate(mlp_top_dims))[1:-1]:
fan_in = mlp_top_dims[layer_idx - 1]
x = nn.Dense(
- fan_out,
- kernel_init=jnn.initializers.normal(
- stddev=jnp.sqrt(2.0 / (fan_in + fan_out))),
- bias_init=jnn.initializers.normal(
- stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))(
- top_mlp_input)
+ fan_out,
+ kernel_init=jnn.initializers.normal(
+ stddev=jnp.sqrt(2.0 / (fan_in + fan_out))
+ ),
+ bias_init=jnn.initializers.normal(
+ stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])
+ ),
+ )(top_mlp_input)
x = nn.relu(x)
- if self.dropout_rate and layer_idx == num_layers_top - 2:
- x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
+ if dropout_rate and layer_idx == num_layers_top - 2:
+ x = Dropout(dropout_rate, deterministic=not train)(x, rate=dropout_rate)
top_mlp_input += x
# In the DLRM model the last layer width is always 1. We can hardcode that
# below.
logits = nn.Dense(
- 1,
- kernel_init=jnn.initializers.normal(
- stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))),
- bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0)))(
- top_mlp_input)
+ 1,
+ kernel_init=jnn.initializers.normal(
+ stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))
+ ),
+ bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0)),
+ )(top_mlp_input)
return logits
@@ -114,16 +121,18 @@ def dot_interact(concat_features):
batch_size = concat_features.shape[0]
# Interact features, select upper or lower-triangular portion, and reshape.
- xactions = jnp.matmul(concat_features,
- jnp.transpose(concat_features, [0, 2, 1]))
+ xactions = jnp.matmul(
+ concat_features, jnp.transpose(concat_features, [0, 2, 1])
+ )
feature_dim = xactions.shape[-1]
indices = jnp.array(jnp.triu_indices(feature_dim))
num_elems = indices.shape[1]
indices = jnp.tile(indices, [1, batch_size])
indices0 = jnp.reshape(
- jnp.tile(jnp.reshape(jnp.arange(batch_size), [-1, 1]), [1, num_elems]),
- [1, -1])
+ jnp.tile(jnp.reshape(jnp.arange(batch_size), [-1, 1]), [1, num_elems]),
+ [1, -1],
+ )
indices = tuple(jnp.concatenate((indices0, indices), 0))
activations = xactions[indices]
activations = jnp.reshape(activations, [batch_size, -1])
@@ -151,25 +160,25 @@ class DlrmSmall(nn.Module):
embedding_init_multiplier: float = None
@nn.compact
- def __call__(self, x, train):
+ def __call__(self, x, train, dropout_rate=DROPOUT_RATE):
bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
# Bottom MLP.
for dense_dim in self.mlp_bottom_dims:
bot_mlp_input = nn.Dense(
- dense_dim,
- kernel_init=jnn.initializers.glorot_uniform(),
- bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)),
- )(
- bot_mlp_input)
+ dense_dim,
+ kernel_init=jnn.initializers.glorot_uniform(),
+ bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)),
+ )(bot_mlp_input)
bot_mlp_input = nn.relu(bot_mlp_input)
if self.use_layer_norm:
bot_mlp_input = nn.LayerNorm()(bot_mlp_input)
bot_mlp_output = bot_mlp_input
batch_size = bot_mlp_output.shape[0]
- feature_stack = jnp.reshape(bot_mlp_output,
- [batch_size, -1, self.embed_dim])
+ feature_stack = jnp.reshape(
+ bot_mlp_output, [batch_size, -1, self.embed_dim]
+ )
# Embedding table look-up.
idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size
@@ -182,38 +191,45 @@ def __call__(self, x, train):
def scaled_init(key, shape, dtype=jnp.float_):
return jnn.initializers.uniform(scale=1.0)(key, shape, dtype) * scale
- embedding_table = self.param('embedding_table',
- scaled_init, [self.vocab_size, self.embed_dim])
+ embedding_table = self.param(
+ 'embedding_table', scaled_init, [self.vocab_size, self.embed_dim]
+ )
idx_lookup = jnp.reshape(idx_lookup, [-1])
embed_features = embedding_table[idx_lookup]
- embed_features = jnp.reshape(embed_features,
- [batch_size, -1, self.embed_dim])
+ embed_features = jnp.reshape(
+ embed_features, [batch_size, -1, self.embed_dim]
+ )
if self.use_layer_norm:
embed_features = nn.LayerNorm()(embed_features)
feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1)
dot_interact_output = dot_interact(concat_features=feature_stack)
- top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output],
- axis=-1)
+ top_mlp_input = jnp.concatenate(
+ [bot_mlp_output, dot_interact_output], axis=-1
+ )
mlp_input_dim = top_mlp_input.shape[1]
mlp_top_dims = self.mlp_top_dims
num_layers_top = len(mlp_top_dims)
for layer_idx, fan_out in enumerate(mlp_top_dims):
fan_in = mlp_input_dim if layer_idx == 0 else mlp_top_dims[layer_idx - 1]
top_mlp_input = nn.Dense(
- fan_out,
- kernel_init=jnn.initializers.normal(
- stddev=jnp.sqrt(2.0 / (fan_in + fan_out))),
- bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))(
- top_mlp_input)
+ fan_out,
+ kernel_init=jnn.initializers.normal(
+ stddev=jnp.sqrt(2.0 / (fan_in + fan_out))
+ ),
+ bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)),
+ )(top_mlp_input)
if layer_idx < (num_layers_top - 1):
top_mlp_input = nn.relu(top_mlp_input)
if self.use_layer_norm:
top_mlp_input = nn.LayerNorm()(top_mlp_input)
- if (self.dropout_rate is not None and self.dropout_rate > 0.0 and
- layer_idx == num_layers_top - 2):
- top_mlp_input = nn.Dropout(
- rate=self.dropout_rate, deterministic=not train)(
- top_mlp_input)
+ if (
+ dropout_rate is not None
+ and dropout_rate > 0.0
+ and layer_idx == num_layers_top - 2
+ ):
+ top_mlp_input = Dropout(dropout_rate, deterministic=not train)(
+ top_mlp_input, rate=dropout_rate
+ )
logits = top_mlp_input
return logits
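
The JAX models above swap `flax.linen.Dropout` (rate fixed at construction) for `algoperf.jax_utils.Dropout`, whose rate can also be overridden at call time. That module's implementation is not part of this diff; the sketch below only illustrates the assumed call-time-rate semantics, and the class name `CallTimeDropout` is a placeholder, not the shipped class.

```python
from typing import Optional, Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp


class CallTimeDropout(nn.Module):
  """Hypothetical stand-in for algoperf.jax_utils.Dropout (not in this diff)."""

  rate: float = 0.0
  deterministic: bool = False
  broadcast_dims: Sequence[int] = ()

  @nn.compact
  def __call__(self, x, rate: Optional[float] = None, deterministic=None):
    # Fall back to the construction-time values when no override is given.
    rate = self.rate if rate is None else rate
    deterministic = (
      self.deterministic if deterministic is None else deterministic
    )
    if deterministic or rate == 0.0:
      return x
    keep_prob = 1.0 - rate
    # Optionally share one mask across the broadcast dimensions (e.g. H, W).
    mask_shape = list(x.shape)
    for dim in self.broadcast_dims:
      mask_shape[dim] = 1
    mask = jax.random.bernoulli(
      self.make_rng('dropout'), p=keep_prob, shape=tuple(mask_shape)
    )
    return jnp.where(
      jnp.broadcast_to(mask, x.shape), x / keep_prob, jnp.zeros_like(x)
    )
```

Under these assumed semantics, `Dropout(dropout_rate, deterministic=not train)(x, rate=dropout_rate)` behaves like `nn.Dropout(rate=dropout_rate, deterministic=not train)(x)`, except that a submission can change the rate on every forward pass without rebuilding the model.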
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
index 91761e458..283b3be8e 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -3,26 +3,24 @@
import functools
from typing import Dict, Optional, Tuple
-from flax import jax_utils
import jax
import jax.numpy as jnp
import numpy as np
+from flax import jax_utils
-from algoperf import param_utils
-from algoperf import spec
+from algoperf import param_utils, spec
from algoperf.workloads.criteo1tb.criteo1tb_jax import models
-from algoperf.workloads.criteo1tb.workload import \
- BaseCriteo1TbDlrmSmallWorkload
+from algoperf.workloads.criteo1tb.workload import BaseCriteo1TbDlrmSmallWorkload
class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):
-
@property
def eval_batch_size(self) -> int:
return 131_072
def _per_example_sigmoid_binary_cross_entropy(
- self, logits: spec.Tensor, targets: spec.Tensor) -> spec.Tensor:
+ self, logits: spec.Tensor, targets: spec.Tensor
+ ) -> spec.Tensor:
"""Computes the sigmoid binary cross entropy per example.
Args:
@@ -39,11 +37,12 @@ def _per_example_sigmoid_binary_cross_entropy(
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense (not one-hot) labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense (not one-hot) labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -55,7 +54,8 @@ def loss_fn(
label_batch = jnp.reshape(label_batch, (batch_size,))
logits_batch = jnp.reshape(logits_batch, (batch_size,))
per_example_losses = self._per_example_sigmoid_binary_cross_entropy(
- logits=logits_batch, targets=label_batch)
+ logits=logits_batch, targets=label_batch
+ )
if mask_batch is not None:
mask_batch = jnp.reshape(mask_batch, (batch_size,))
per_example_losses *= mask_batch
@@ -64,35 +64,33 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': n_valid_examples,
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': n_valid_examples,
+ 'per_example': per_example_losses,
}
def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None,
- tabulate: Optional[bool] = False,
+ self,
+ rng: spec.RandomState,
+ tabulate: Optional[bool] = False,
) -> spec.ModelInitState:
"""Only dropout is used."""
- del aux_dropout_rate
if self.use_resnet:
model_class = models.DLRMResNet
else:
model_class = models.DlrmSmall
+
self._model = model_class(
- vocab_size=self.vocab_size,
- num_dense_features=self.num_dense_features,
- mlp_bottom_dims=self.mlp_bottom_dims,
- mlp_top_dims=self.mlp_top_dims,
- embed_dim=self.embed_dim,
- dropout_rate=dropout_rate,
- use_layer_norm=self.use_layer_norm,
- embedding_init_multiplier=self.embedding_init_multiplier)
-
- params_rng, dropout_rng = jax.random.split(rng)
+ vocab_size=self.vocab_size,
+ num_dense_features=self.num_dense_features,
+ mlp_bottom_dims=self.mlp_bottom_dims,
+ mlp_top_dims=self.mlp_top_dims,
+ embed_dim=self.embed_dim,
+ use_layer_norm=self.use_layer_norm,
+ embedding_init_multiplier=self.embedding_init_multiplier,
+ )
+
+ params_rng, _ = jax.random.split(rng)
init_fake_batch_size = 2
num_categorical_features = 26
num_dense_features = 13
@@ -100,8 +98,11 @@ def init_model_fn(
input_shape = (init_fake_batch_size, input_size)
init_fn = functools.partial(self._model.init, train=False)
initial_variables = jax.jit(init_fn)(
- {'params': params_rng, 'dropout': dropout_rng},
- jnp.ones(input_shape, jnp.float32))
+ {
+ 'params': params_rng,
+ },
+ jnp.ones(input_shape, jnp.float32),
+ )
initial_params = initial_variables['params']
self._param_shapes = param_utils.jax_param_shapes(initial_params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
@@ -111,13 +112,15 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_7'
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
inputs = augmented_and_preprocessed_input_batch['inputs']
@@ -125,39 +128,43 @@ def model_fn(
apply_kwargs = {'train': train}
if train:
apply_kwargs['rngs'] = {'dropout': rng}
+ apply_kwargs['dropout_rate'] = dropout_rate
logits_batch = self._model.apply({'params': params}, inputs, **apply_kwargs)
return logits_batch, None
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, 0, 0),
- static_broadcasted_argnums=(0,))
- def _eval_batch_pmapped(self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor]) -> spec.Tensor:
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, 0, 0),
+ static_broadcasted_argnums=(0,),
+ )
+ def _eval_batch_pmapped(
+ self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]
+ ) -> spec.Tensor:
logits, _ = self.model_fn(
- params,
- batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
weights = batch.get('weights')
if weights is None:
weights = jnp.ones(len(logits))
summed_loss = self.loss_fn(
- label_batch=batch['targets'], logits_batch=logits,
- mask_batch=weights)['summed']
+ label_batch=batch['targets'], logits_batch=logits, mask_batch=weights
+ )['summed']
return summed_loss
- def _eval_batch(self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor]) -> spec.Tensor:
+ def _eval_batch(
+ self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]
+ ) -> spec.Tensor:
# We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of
# shape (local_device_count,) will all be different values.
return np.array(
- self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64)
+ self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64
+ )
class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
@@ -165,7 +172,6 @@ class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
class Criteo1TbDlrmSmallLayerNormWorkload(Criteo1TbDlrmSmallWorkload):
-
@property
def use_layer_norm(self) -> bool:
"""Whether or not to use LayerNorm in the model."""
@@ -199,7 +205,6 @@ def test_target_value(self) -> float:
class Criteo1TbDlrmSmallEmbedInitWorkload(Criteo1TbDlrmSmallWorkload):
-
@property
def validation_target_value(self) -> float:
return 0.129657
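
With the constructor arguments removed from `init_model_fn`, the dropout rate now travels through `model_fn`. A hedged usage sketch follows; the helper name `train_step_logits` and its argument names are placeholders for illustration, not part of the benchmark API.

```python
from typing import Dict

from algoperf import spec


def train_step_logits(
  workload,
  params,
  batch: Dict[str, spec.Tensor],
  rng: spec.RandomState,
  dropout_rate: float = 0.1,
) -> spec.Tensor:
  """Forward pass with a submission-chosen dropout rate."""
  logits, _ = workload.model_fn(
    params,
    batch,
    model_state=None,
    mode=spec.ForwardPassMode.TRAIN,
    rng=rng,
    update_batch_norm=False,
    dropout_rate=dropout_rate,  # Falls back to models.DROPOUT_RATE if omitted.
  )
  return logits
```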
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py
index 7a40f0e81..1906bf7ae 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py
@@ -5,9 +5,13 @@
import torch
from torch import nn
+from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout
+
+DROPOUT_RATE = 0.0
+
class DenseBlock(nn.Module):
- """Dense block with optional residual connection.""" ""
+ """Dense block with optional residual connection.""" ''
def __init__(self, module, resnet=False):
super().__init__()
@@ -15,10 +19,20 @@ def __init__(self, module, resnet=False):
self.resnet = resnet
def forward(self, x):
- if self.resnet:
- return self.module(x) + x
- else:
- return self.module(x)
+ return self.module(x) + x if self.resnet else self.module(x)
+
+
+class DenseBlockWithDropout(nn.Module):
+ """Dense block with optional residual connection and support for dropout."""
+
+ def __init__(self, module, resnet=False):
+ super().__init__()
+ self.module = module
+ self.resnet = resnet
+ self._supports_custom_dropout = True
+
+ def forward(self, x, p):
+ return self.module(x, p) + x if self.resnet else self.module(x, p)
class DotInteract(nn.Module):
@@ -26,17 +40,20 @@ class DotInteract(nn.Module):
def __init__(self, num_sparse_features):
super().__init__()
- self.triu_indices = torch.triu_indices(num_sparse_features + 1,
- num_sparse_features + 1)
+ self.triu_indices = torch.triu_indices(
+ num_sparse_features + 1, num_sparse_features + 1
+ )
def forward(self, dense_features, sparse_features):
- combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features),
- dim=1)
- interactions = torch.bmm(combined_values,
- torch.transpose(combined_values, 1, 2))
- interactions_flat = interactions[:,
- self.triu_indices[0],
- self.triu_indices[1]]
+ combined_values = torch.cat(
+ (dense_features.unsqueeze(1), sparse_features), dim=1
+ )
+ interactions = torch.bmm(
+ combined_values, torch.transpose(combined_values, 1, 2)
+ )
+ interactions_flat = interactions[
+ :, self.triu_indices[0], self.triu_indices[1]
+ ]
return torch.cat((dense_features, interactions_flat), dim=1)
@@ -51,16 +68,17 @@ class DLRMResNet(nn.Module):
embed_dim: embedding dimension.
"""
- def __init__(self,
- vocab_size,
- num_dense_features=13,
- num_sparse_features=26,
- mlp_bottom_dims=(256, 256, 256),
- mlp_top_dims=(256, 256, 256, 256, 1),
- embed_dim=128,
- dropout_rate=0.0,
- use_layer_norm=False,
- embedding_init_multiplier=None):
+ def __init__(
+ self,
+ vocab_size,
+ num_dense_features=13,
+ num_sparse_features=26,
+ mlp_bottom_dims=(256, 256, 256),
+ mlp_top_dims=(256, 256, 256, 256, 1),
+ embed_dim=128,
+ use_layer_norm=False,
+ embedding_init_multiplier=None,
+ ):
super().__init__()
self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32)
self.num_dense_features = num_dense_features
@@ -78,7 +96,8 @@ def __init__(self,
scale = 1.0 / torch.sqrt(self.vocab_size)
for i in range(num_chunks):
chunk = nn.Parameter(
- torch.Tensor(self.vocab_size // num_chunks, self.embed_dim))
+ torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)
+ )
chunk.data.uniform_(0, 1)
chunk.data = scale * chunk.data
self.register_parameter(f'embedding_chunk_{i}', chunk)
@@ -101,11 +120,11 @@ def __init__(self,
for module in self.bot_mlp.modules():
if isinstance(module, nn.Linear):
- limit = math.sqrt(6. / (module.in_features + module.out_features))
+ limit = math.sqrt(6.0 / (module.in_features + module.out_features))
nn.init.uniform_(module.weight.data, -limit, limit)
- nn.init.normal_(module.bias.data,
- 0.,
- math.sqrt(1. / module.out_features))
+ nn.init.normal_(
+ module.bias.data, 0.0, math.sqrt(1.0 / module.out_features)
+ )
# Number of sparse features = 26
fan_in = (26 * self.embed_dim) + self.mlp_bottom_dims[-1]
@@ -116,33 +135,34 @@ def __init__(self,
block.append(nn.Linear(fan_in, fan_out))
if layer_idx < (num_layers_top - 1):
block.append(nn.ReLU(inplace=True))
- if (dropout_rate is not None and dropout_rate > 0.0 and
- layer_idx == num_layers_top - 2):
- block.append(nn.Dropout(p=dropout_rate))
- block = nn.Sequential(*block)
+ if layer_idx == num_layers_top - 2:
+ block.append(CustomDropout())
+ block = SequentialWithDropout(*block)
if (layer_idx != 0) and (layer_idx != num_layers_top - 1):
- block = DenseBlock(block, resnet=True)
+ block = DenseBlockWithDropout(block, resnet=True)
else:
- block = DenseBlock(block)
+ block = DenseBlockWithDropout(block)
mlp_top_blocks.append(block)
fan_in = fan_out
- self.top_mlp = nn.Sequential(*mlp_top_blocks)
+ self.top_mlp = SequentialWithDropout(*mlp_top_blocks)
for module in self.top_mlp.modules():
if isinstance(module, nn.Linear):
nn.init.normal_(
- module.weight.data,
- 0.,
- math.sqrt(2. / (module.in_features + module.out_features)))
- nn.init.normal_(module.bias.data,
- 0.,
- math.sqrt(1. / module.out_features))
+ module.weight.data,
+ 0.0,
+ math.sqrt(2.0 / (module.in_features + module.out_features)),
+ )
+ nn.init.normal_(
+ module.bias.data, 0.0, math.sqrt(1.0 / module.out_features)
+ )
- def forward(self, x):
+ def forward(self, x, dropout_rate=DROPOUT_RATE):
batch_size = x.shape[0]
dense_features, sparse_features = torch.split(
- x, [self.num_dense_features, self.num_sparse_features], 1)
+ x, [self.num_dense_features, self.num_sparse_features], 1
+ )
# Bottom MLP.
embedded_dense = self.bot_mlp(dense_features)
@@ -152,12 +172,13 @@ def forward(self, x):
idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size
embedding_table = torch.cat(self.embedding_table_chucks, dim=0)
embedded_sparse = embedding_table[idx_lookup]
- embedded_sparse = torch.reshape(embedded_sparse,
- [batch_size, 26 * self.embed_dim])
+ embedded_sparse = torch.reshape(
+ embedded_sparse, [batch_size, 26 * self.embed_dim]
+ )
top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1)
# Final MLP.
- logits = self.top_mlp(top_mlp_input)
+ logits = self.top_mlp(top_mlp_input, dropout_rate)
return logits
@@ -172,16 +193,17 @@ class DlrmSmall(nn.Module):
embed_dim: embedding dimension.
"""
- def __init__(self,
- vocab_size,
- num_dense_features=13,
- num_sparse_features=26,
- mlp_bottom_dims=(512, 256, 128),
- mlp_top_dims=(1024, 1024, 512, 256, 1),
- embed_dim=128,
- dropout_rate=0.0,
- use_layer_norm=False,
- embedding_init_multiplier=None):
+ def __init__(
+ self,
+ vocab_size,
+ num_dense_features=13,
+ num_sparse_features=26,
+ mlp_bottom_dims=(512, 256, 128),
+ mlp_top_dims=(1024, 1024, 512, 256, 1),
+ embed_dim=128,
+ use_layer_norm=False,
+ embedding_init_multiplier=None,
+ ):
super().__init__()
self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32)
self.num_dense_features = num_dense_features
@@ -205,7 +227,8 @@ def __init__(self,
for i in range(num_chunks):
chunk = nn.Parameter(
- torch.Tensor(self.vocab_size // num_chunks, self.embed_dim))
+ torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)
+ )
chunk.data.uniform_(0, 1)
chunk.data = scale * chunk.data
self.register_parameter(f'embedding_chunk_{i}', chunk)
@@ -222,30 +245,32 @@ def __init__(self,
self.bot_mlp = nn.Sequential(*bottom_mlp_layers)
for module in self.bot_mlp.modules():
if isinstance(module, nn.Linear):
- limit = math.sqrt(6. / (module.in_features + module.out_features))
+ limit = math.sqrt(6.0 / (module.in_features + module.out_features))
nn.init.uniform_(module.weight.data, -limit, limit)
- nn.init.normal_(module.bias.data,
- 0.,
- math.sqrt(1. / module.out_features))
+ nn.init.normal_(
+ module.bias.data, 0.0, math.sqrt(1.0 / module.out_features)
+ )
- self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,)
+ self.dot_interact = DotInteract(
+ num_sparse_features=num_sparse_features,
+ )
# TODO: Write down the formula here instead of the constant.
input_dims = 506
num_layers_top = len(self.mlp_top_dims)
top_mlp_layers = []
for layer_idx, fan_out in enumerate(self.mlp_top_dims):
- fan_in = input_dims if layer_idx == 0 \
- else self.mlp_top_dims[layer_idx - 1]
+ fan_in = (
+ input_dims if layer_idx == 0 else self.mlp_top_dims[layer_idx - 1]
+ )
top_mlp_layers.append(nn.Linear(fan_in, fan_out))
if layer_idx < (num_layers_top - 1):
top_mlp_layers.append(nn.ReLU(inplace=True))
if use_layer_norm:
top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6))
- if (dropout_rate is not None and dropout_rate > 0.0 and
- layer_idx == num_layers_top - 2):
- top_mlp_layers.append(nn.Dropout(p=dropout_rate))
- self.top_mlp = nn.Sequential(*top_mlp_layers)
+ if layer_idx == num_layers_top - 2:
+ top_mlp_layers.append(CustomDropout())
+ self.top_mlp = SequentialWithDropout(*top_mlp_layers)
if use_layer_norm:
self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6)
else:
@@ -253,18 +278,20 @@ def __init__(self,
for module in self.top_mlp.modules():
if isinstance(module, nn.Linear):
nn.init.normal_(
- module.weight.data,
- 0.,
- math.sqrt(2. / (module.in_features + module.out_features)))
- nn.init.normal_(module.bias.data,
- 0.,
- math.sqrt(1. / module.out_features))
+ module.weight.data,
+ 0.0,
+ math.sqrt(2.0 / (module.in_features + module.out_features)),
+ )
+ nn.init.normal_(
+ module.bias.data, 0.0, math.sqrt(1.0 / module.out_features)
+ )
- def forward(self, x):
+ def forward(self, x, dropout_rate=DROPOUT_RATE):
batch_size = x.shape[0]
dense_features, sparse_features = torch.split(
- x, [self.num_dense_features, self.num_sparse_features], 1)
+ x, [self.num_dense_features, self.num_sparse_features], 1
+ )
# Bottom MLP.
embedded_dense = self.bot_mlp(dense_features)
@@ -274,14 +301,16 @@ def forward(self, x):
idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size
embedding_table = torch.cat(self.embedding_table_chucks, dim=0)
embedded_sparse = embedding_table[idx_lookup]
- embedded_sparse = torch.reshape(embedded_sparse,
- [batch_size, -1, self.embed_dim])
+ embedded_sparse = torch.reshape(
+ embedded_sparse, [batch_size, -1, self.embed_dim]
+ )
if self.embed_ln:
embedded_sparse = self.embed_ln(embedded_sparse)
# Dot product interactions.
concatenated_dense = self.dot_interact(
- dense_features=embedded_dense, sparse_features=embedded_sparse)
+ dense_features=embedded_dense, sparse_features=embedded_sparse
+ )
# Final MLP.
- logits = self.top_mlp(concatenated_dense)
+ logits = self.top_mlp(concatenated_dense, dropout_rate)
return logits
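
On the PyTorch side the same idea relies on `CustomDropout` and `SequentialWithDropout` from `algoperf.pytorch_utils`, which this diff only imports. Below is a minimal sketch of the assumed behaviour; the class names carry a `Sketch` suffix to make clear they are illustrations, not the shipped implementations.

```python
import torch.nn.functional as F
from torch import nn


class CustomDropoutSketch(nn.Module):
  """Dropout whose rate is supplied at forward time, not at construction."""

  def __init__(self):
    super().__init__()
    self._supports_custom_dropout = True

  def forward(self, x, p):
    return F.dropout(x, p=p, training=self.training)


class SequentialWithDropoutSketch(nn.Sequential):
  """Sequential container that forwards the rate to dropout-aware children."""

  def forward(self, x, p):
    for module in self:
      if getattr(module, '_supports_custom_dropout', False):
        x = module(x, p)  # e.g. a dropout layer or DenseBlockWithDropout.
      else:
        x = module(x)
    return x
```

This mirrors how `DenseBlockWithDropout` above advertises itself via `_supports_custom_dropout = True`, so `self.top_mlp(top_mlp_input, dropout_rate)` can thread the rate through an otherwise ordinary sequential stack.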
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
index 726aa8705..74f91de43 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
@@ -7,24 +7,22 @@
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
-from algoperf import param_utils
-from algoperf import spec
+from algoperf import param_utils, spec
from algoperf.pytorch_utils import pytorch_setup
from algoperf.workloads.criteo1tb.criteo1tb_pytorch import models
-from algoperf.workloads.criteo1tb.workload import \
- BaseCriteo1TbDlrmSmallWorkload
+from algoperf.workloads.criteo1tb.workload import BaseCriteo1TbDlrmSmallWorkload
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):
-
@property
def eval_batch_size(self) -> int:
return 8_192
def _per_example_sigmoid_binary_cross_entropy(
- self, logits: spec.Tensor, targets: spec.Tensor) -> spec.Tensor:
+ self, logits: spec.Tensor, targets: spec.Tensor
+ ) -> spec.Tensor:
ls = torch.nn.LogSigmoid()
log_p = ls(logits)
log_not_p = ls(-logits)
@@ -35,11 +33,12 @@ def _per_example_sigmoid_binary_cross_entropy(
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense (not one-hot) labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense (not one-hot) labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -51,7 +50,8 @@ def loss_fn(
label_batch = torch.reshape(label_batch, (batch_size,))
logits_batch = torch.reshape(logits_batch, (batch_size,))
per_example_losses = self._per_example_sigmoid_binary_cross_entropy(
- logits=logits_batch, targets=label_batch)
+ logits=logits_batch, targets=label_batch
+ )
if mask_batch is not None:
mask_batch = torch.reshape(mask_batch, (batch_size,))
per_example_losses *= mask_batch
@@ -60,18 +60,12 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
+ 'per_example': per_example_losses,
}
- def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """Only dropout is used."""
- del aux_dropout_rate
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
# Disable cudnn benchmark to avoid OOM errors.
torch.backends.cudnn.benchmark = False
@@ -80,14 +74,14 @@ def init_model_fn(
else:
model_class = models.DlrmSmall
model = model_class(
- vocab_size=self.vocab_size,
- num_dense_features=self.num_dense_features,
- mlp_bottom_dims=self.mlp_bottom_dims,
- mlp_top_dims=self.mlp_top_dims,
- embed_dim=self.embed_dim,
- dropout_rate=dropout_rate,
- use_layer_norm=self.use_layer_norm,
- embedding_init_multiplier=self.embedding_init_multiplier)
+ vocab_size=self.vocab_size,
+ num_dense_features=self.num_dense_features,
+ mlp_bottom_dims=self.mlp_bottom_dims,
+ mlp_top_dims=self.mlp_top_dims,
+ embed_dim=self.embed_dim,
+ use_layer_norm=self.use_layer_norm,
+ embedding_init_multiplier=self.embedding_init_multiplier,
+ )
self._param_shapes = param_utils.pytorch_param_shapes(model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
model.to(DEVICE)
@@ -102,13 +96,15 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['top_mlp.4.weight', 'top_mlp.4.bias']
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
@@ -123,24 +119,25 @@ def model_fn(
model.train()
contexts = {
- spec.ForwardPassMode.EVAL: torch.no_grad,
- spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+ spec.ForwardPassMode.EVAL: torch.no_grad,
+ spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
- logits_batch = model(inputs)
+ logits_batch = model(inputs, dropout_rate=dropout_rate)
return logits_batch, None
def _build_input_queue(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
not_train = split != 'train'
per_device_batch_size = int(global_batch_size / N_GPUS)
@@ -148,35 +145,42 @@ def _build_input_queue(
# avoid creating too many threads.
if RANK == 0:
np_iter = super()._build_input_queue(
- data_rng=data_rng,
- split=split,
- data_dir=data_dir,
- global_batch_size=global_batch_size,
- num_batches=num_batches,
- repeat_final_dataset=repeat_final_dataset)
+ data_rng=data_rng,
+ split=split,
+ data_dir=data_dir,
+ global_batch_size=global_batch_size,
+ num_batches=num_batches,
+ repeat_final_dataset=repeat_final_dataset,
+ )
weights = None
while True:
if RANK == 0:
batch = next(np_iter) # pylint: disable=stop-iteration-return
inputs = torch.as_tensor(
- batch['inputs'], dtype=torch.float32, device=DEVICE)
+ batch['inputs'], dtype=torch.float32, device=DEVICE
+ )
targets = torch.as_tensor(
- batch['targets'], dtype=torch.float32, device=DEVICE)
+ batch['targets'], dtype=torch.float32, device=DEVICE
+ )
if not_train:
weights = batch.get('weights')
if weights is None:
- weights = torch.ones((N_GPUS, per_device_batch_size, 1),
- dtype=torch.float32,
- device=DEVICE)
+ weights = torch.ones(
+ (N_GPUS, per_device_batch_size, 1),
+ dtype=torch.float32,
+ device=DEVICE,
+ )
else:
weights = torch.as_tensor(
- weights, dtype=torch.float32, device=DEVICE)
+ weights, dtype=torch.float32, device=DEVICE
+ )
# Send batch to other devices when using DDP.
if USE_PYTORCH_DDP:
if not_train:
# During eval, the batch size of the remainder might be different.
per_device_batch_size = torch.tensor(
- len(targets[0]), dtype=torch.int32, device=DEVICE)
+ len(targets[0]), dtype=torch.int32, device=DEVICE
+ )
dist.broadcast(per_device_batch_size, src=0)
dist.broadcast(weights, src=0)
weights = weights[0]
@@ -192,52 +196,57 @@ def _build_input_queue(
else:
if not_train:
# During eval, the batch size of the remainder might be different.
- per_device_batch_size = torch.empty((1,),
- dtype=torch.int32,
- device=DEVICE)
+ per_device_batch_size = torch.empty(
+ (1,), dtype=torch.int32, device=DEVICE
+ )
dist.broadcast(per_device_batch_size, src=0)
- weights = torch.empty((N_GPUS, per_device_batch_size, 1),
- dtype=torch.float32,
- device=DEVICE)
+ weights = torch.empty(
+ (N_GPUS, per_device_batch_size, 1),
+ dtype=torch.float32,
+ device=DEVICE,
+ )
dist.broadcast(weights, src=0)
weights = weights[RANK]
- inputs = torch.empty((N_GPUS, per_device_batch_size, 39),
- dtype=torch.float32,
- device=DEVICE)
+ inputs = torch.empty(
+ (N_GPUS, per_device_batch_size, 39),
+ dtype=torch.float32,
+ device=DEVICE,
+ )
dist.broadcast(inputs, src=0)
inputs = inputs[RANK]
- targets = torch.empty((N_GPUS, per_device_batch_size, 1),
- dtype=torch.float32,
- device=DEVICE)
+ targets = torch.empty(
+ (N_GPUS, per_device_batch_size, 1), dtype=torch.float32, device=DEVICE
+ )
dist.broadcast(targets, src=0)
targets = targets[RANK]
if weights is None:
weights = torch.ones(per_device_batch_size, device=DEVICE)
batch = {
- 'inputs': inputs,
- 'targets': targets,
- 'weights': weights,
+ 'inputs': inputs,
+ 'targets': targets,
+ 'weights': weights,
}
yield batch
- def _eval_batch(self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor]) -> spec.Tensor:
+ def _eval_batch(
+ self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]
+ ) -> spec.Tensor:
logits, _ = self.model_fn(
- params,
- batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
weights = batch.get('weights')
if weights is None:
weights = torch.ones(len(logits), device=DEVICE)
summed_loss = self.loss_fn(
- label_batch=batch['targets'], logits_batch=logits,
- mask_batch=weights)['summed']
+ label_batch=batch['targets'], logits_batch=logits, mask_batch=weights
+ )['summed']
return summed_loss.to(dtype=torch.float64)
@@ -246,7 +255,6 @@ class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
class Criteo1TbDlrmSmallLayerNormWorkload(Criteo1TbDlrmSmallWorkload):
-
@property
def use_layer_norm(self) -> bool:
"""Whether or not to use LayerNorm in the model."""
@@ -280,7 +288,6 @@ def test_target_value(self) -> float:
class Criteo1TbDlrmSmallEmbedInitWorkload(Criteo1TbDlrmSmallWorkload):
-
@property
def validation_target_value(self) -> float:
return 0.129657
diff --git a/algoperf/workloads/criteo1tb/input_pipeline.py b/algoperf/workloads/criteo1tb/input_pipeline.py
index 7e254336a..bce8b11c4 100644
--- a/algoperf/workloads/criteo1tb/input_pipeline.py
+++ b/algoperf/workloads/criteo1tb/input_pipeline.py
@@ -19,32 +19,32 @@
# Raw vocab sizes from
# https://cloud.google.com/tpu/docs/tutorials/dlrm-dcn-2.x#run-model.
_VOCAB_SIZES = [
- 39884406,
- 39043,
- 17289,
- 7420,
- 20263,
- 3,
- 7120,
- 1543,
- 63,
- 38532951,
- 2953546,
- 403346,
- 10,
- 2208,
- 11938,
- 155,
- 4,
- 976,
- 14,
- 39979771,
- 25641295,
- 39664984,
- 585935,
- 12972,
- 108,
- 36,
+ 39884406,
+ 39043,
+ 17289,
+ 7420,
+ 20263,
+ 3,
+ 7120,
+ 1543,
+ 63,
+ 38532951,
+ 2953546,
+ 403346,
+ 10,
+ 2208,
+ 11938,
+ 155,
+ 4,
+ 976,
+ 14,
+ 39979771,
+ 25641295,
+ 39664984,
+ 585935,
+ 12972,
+ 108,
+ 36,
]
@@ -60,7 +60,8 @@ def _parse_example_fn(num_dense_features, example):
categorical_defaults = [['00000000'] for _ in range(len(_VOCAB_SIZES))]
record_defaults = label_defaults + int_defaults + categorical_defaults
fields = tf.io.decode_csv(
- example, record_defaults, field_delim='\t', na_value='-1')
+ example, record_defaults, field_delim='\t', na_value='-1'
+ )
num_labels = 1
features = {}
@@ -78,20 +79,24 @@ def _parse_example_fn(num_dense_features, example):
# We append the column index to the string to make the same id in different
# columns unique.
cat_features.append(
- tf.strings.to_hash_bucket_fast(field + str(idx), _VOCAB_SIZES[idx]))
+ tf.strings.to_hash_bucket_fast(field + str(idx), _VOCAB_SIZES[idx])
+ )
cat_features = tf.cast(
- tf.stack(cat_features, axis=1), dtype=int_features.dtype)
+ tf.stack(cat_features, axis=1), dtype=int_features.dtype
+ )
features['inputs'] = tf.concat([int_features, cat_features], axis=1)
return features
-def get_criteo1tb_dataset(split: str,
- shuffle_rng,
- data_dir: str,
- num_dense_features: int,
- global_batch_size: int,
- num_batches: Optional[int] = None,
- repeat_final_dataset: bool = False):
+def get_criteo1tb_dataset(
+ split: str,
+ shuffle_rng,
+ data_dir: str,
+ num_dense_features: int,
+ global_batch_size: int,
+ num_batches: Optional[int] = None,
+ repeat_final_dataset: bool = False,
+):
"""Get the Criteo 1TB dataset for a given split."""
num_test_files = _NUM_DAY_23_FILES // 2 + 1
if split in ['train', 'eval_train']:
@@ -99,19 +104,20 @@ def get_criteo1tb_dataset(split: str,
elif split == 'validation':
# Assumes files are of the format day_23_04.
file_paths = [
- os.path.join(data_dir, f'day_23_{str(s).zfill(2)}')
- for s in range(num_test_files, _NUM_DAY_23_FILES)
+ os.path.join(data_dir, f'day_23_{str(s).zfill(2)}')
+ for s in range(num_test_files, _NUM_DAY_23_FILES)
]
else:
file_paths = [
- os.path.join(data_dir, f'day_23_{str(s).zfill(2)}')
- for s in range(0, num_test_files)
+ os.path.join(data_dir, f'day_23_{str(s).zfill(2)}')
+ for s in range(0, num_test_files)
]
is_training = split == 'train'
shuffle = is_training or split == 'eval_train'
ds = tf.data.Dataset.list_files(
- file_paths, shuffle=shuffle, seed=shuffle_rng[0])
+ file_paths, shuffle=shuffle, seed=shuffle_rng[0]
+ )
if shuffle:
ds = ds.shuffle(buffer_size=1024)
@@ -132,9 +138,10 @@ def get_criteo1tb_dataset(split: str,
ds = ds.repeat()
ds = map(
- functools.partial(
- data_utils.shard_and_maybe_pad_np,
- global_batch_size=global_batch_size),
- ds)
+ functools.partial(
+ data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size
+ ),
+ ds,
+ )
return ds
diff --git a/algoperf/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py
index 617b2e987..2cb7e5450 100644
--- a/algoperf/workloads/criteo1tb/workload.py
+++ b/algoperf/workloads/criteo1tb/workload.py
@@ -4,8 +4,8 @@
import os
from typing import Dict, Iterator, Optional, Tuple
-from absl import flags
import torch.distributed as dist
+from absl import flags
from algoperf import spec
from algoperf.workloads.criteo1tb import input_pipeline
@@ -29,8 +29,9 @@ def target_metric_name(self) -> str:
"""The name of the target metric (useful for scoring/processing code)."""
return 'loss'
- def has_reached_validation_target(self, eval_result: Dict[str,
- float]) -> bool:
+ def has_reached_validation_target(
+ self, eval_result: Dict[str, float]
+ ) -> bool:
return eval_result['validation/loss'] < self.validation_target_value
@property
@@ -71,8 +72,9 @@ def num_eval_train_examples(self) -> int:
# Round up from num_validation_examples (which is the default for
# num_eval_train_examples) to the next multiple of eval_batch_size, so that
# we don't have to extract the correctly sized subset of the training data.
- rounded_up_multiple = math.ceil(self.num_validation_examples /
- self.eval_batch_size)
+ rounded_up_multiple = math.ceil(
+ self.num_validation_examples / self.eval_batch_size
+ )
return rounded_up_multiple * self.eval_batch_size
@property
@@ -100,23 +102,25 @@ def eval_period_time_sec(self) -> int:
return 2 * 60 # 2 mins.
def _build_input_queue(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
del cache
ds = input_pipeline.get_criteo1tb_dataset(
- split=split,
- shuffle_rng=data_rng,
- data_dir=data_dir,
- num_dense_features=self.num_dense_features,
- global_batch_size=global_batch_size,
- num_batches=num_batches,
- repeat_final_dataset=repeat_final_dataset)
+ split=split,
+ shuffle_rng=data_rng,
+ data_dir=data_dir,
+ num_dense_features=self.num_dense_features,
+ global_batch_size=global_batch_size,
+ num_batches=num_batches,
+ repeat_final_dataset=repeat_final_dataset,
+ )
for batch in iter(ds):
yield batch
@@ -126,15 +130,17 @@ def step_hint(self) -> int:
"""Approx. steps the baseline can do in the allowed runtime budget."""
return 10_666
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del model_state
del global_step
@@ -142,12 +148,13 @@ def _eval_model_on_split(self,
if split not in self._eval_iters:
# These iterators will repeat indefinitely.
self._eval_iters[split] = self._build_input_queue(
- data_rng=rng,
- split=split,
- data_dir=data_dir,
- global_batch_size=global_batch_size,
- num_batches=num_batches,
- repeat_final_dataset=True)
+ data_rng=rng,
+ split=split,
+ data_dir=data_dir,
+ global_batch_size=global_batch_size,
+ num_batches=num_batches,
+ repeat_final_dataset=True,
+ )
loss = 0.0
for _ in range(num_batches):
eval_batch = next(self._eval_iters[split])
diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py
index 44bff0e21..b80c370ea 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/models.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/models.py
@@ -12,13 +12,17 @@
Data:
github.com/facebookresearch/fastMRI/tree/main/fastmri/data
"""
+
import functools
-from typing import Optional
import flax.linen as nn
import jax
import jax.numpy as jnp
+from algoperf.jax_utils import Dropout
+
+DROPOUT_RATE = 0.0
+
def _instance_norm2d(x, axes, epsilon=1e-5):
# promote x to at least float32, this avoids half precision computation
@@ -28,7 +32,7 @@ def _instance_norm2d(x, axes, epsilon=1e-5):
mean2 = jnp.mean(jnp.square(x), axes)
# mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
# to floating point round-off errors.
- var = jnp.maximum(0., mean2 - jnp.square(mean))
+ var = jnp.maximum(0.0, mean2 - jnp.square(mean))
stats_shape = list(x.shape)
for axis in axes:
stats_shape[axis] = 1
@@ -43,39 +47,38 @@ def _instance_norm2d(x, axes, epsilon=1e-5):
class UNet(nn.Module):
"""Jax / Flax implementation of a U-Net model.
- O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks
- for biomedical image segmentation. In International Conference on Medical
- image computing and computer-assisted intervention, pages 234–241.
- Springer, 2015.
+ O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks
+ for biomedical image segmentation. In International Conference on Medical
+ image computing and computer-assisted intervention, pages 234–241.
+ Springer, 2015.
- out_channels: Number of channels in the output to the U-Net model.
- channels: Number of output channels of the first convolution layer.
- num_pool_layers: Number of down-sampling and up-sampling layers.
- dropout_rate: Dropout probability.
+ out_channels: Number of channels in the output to the U-Net model.
+ channels: Number of output channels of the first convolution layer.
+ num_pool_layers: Number of down-sampling and up-sampling layers.
+ dropout_rate: Dropout probability.
"""
+
num_channels: int = 32
num_pool_layers: int = 4
out_channels = 1
- dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0.
+ dropout_rate: float = DROPOUT_RATE
use_tanh: bool = False
use_layer_norm: bool = False
@nn.compact
- def __call__(self, x, train=True):
- dropout_rate = self.dropout_rate
- if dropout_rate is None:
- dropout_rate = 0.0
-
+ def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE):
# pylint: disable=invalid-name
_ConvBlock = functools.partial(
- ConvBlock,
- dropout_rate=dropout_rate,
- use_tanh=self.use_tanh,
- use_layer_norm=self.use_layer_norm)
+ ConvBlock,
+ dropout_rate=dropout_rate,
+ use_tanh=self.use_tanh,
+ use_layer_norm=self.use_layer_norm,
+ )
_TransposeConvBlock = functools.partial(
- TransposeConvBlock,
- use_tanh=self.use_tanh,
- use_layer_norm=self.use_layer_norm)
+ TransposeConvBlock,
+ use_tanh=self.use_tanh,
+ use_layer_norm=self.use_layer_norm,
+ )
down_sample_layers = [_ConvBlock(self.num_channels)]
@@ -126,9 +129,9 @@ def __call__(self, x, train=True):
output = jnp.concatenate((output, downsample_layer), axis=-1)
output = conv(output, train)
- output = nn.Conv(
- self.out_channels, kernel_size=(1, 1), strides=(1, 1))(
- output)
+ output = nn.Conv(self.out_channels, kernel_size=(1, 1), strides=(1, 1))(
+ output
+ )
return output.squeeze(-1)
@@ -137,13 +140,14 @@ class ConvBlock(nn.Module):
out_channels: Number of channels in the output.
dropout_rate: Dropout probability.
"""
+
out_channels: int
- dropout_rate: float
use_tanh: bool
use_layer_norm: bool
+ dropout_rate: float = 0.0
@nn.compact
- def __call__(self, x, train=True):
+ def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE):
"""Forward function.
Note: Pytorch is NCHW and jax/flax is NHWC.
Args:
@@ -153,11 +157,11 @@ def __call__(self, x, train=True):
jnp.array: Output tensor of shape `(N, H, W, out_channels)`.
"""
x = nn.Conv(
- features=self.out_channels,
- kernel_size=(3, 3),
- strides=(1, 1),
- use_bias=False)(
- x)
+ features=self.out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ use_bias=False,
+ )(x)
if self.use_layer_norm:
x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x)
else:
@@ -172,23 +176,23 @@ def __call__(self, x, train=True):
x = activation_fn(x)
# Ref code uses dropout2d which applies the same mask for the entire channel
# Replicated by using broadcast dims to have the same filter on HW
- x = nn.Dropout(
- self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
- x)
+ x = Dropout(dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
+ x, rate=dropout_rate
+ )
x = nn.Conv(
- features=self.out_channels,
- kernel_size=(3, 3),
- strides=(1, 1),
- use_bias=False)(
- x)
+ features=self.out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ use_bias=False,
+ )(x)
if self.use_layer_norm:
x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x)
else:
x = _instance_norm2d(x, (1, 2))
x = activation_fn(x)
- x = nn.Dropout(
- self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
- x)
+ x = Dropout(dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
+ x, rate=dropout_rate
+ )
return x
@@ -196,6 +200,7 @@ class TransposeConvBlock(nn.Module):
"""A Transpose Convolutional Block.
out_channels: Number of channels in the output.
"""
+
out_channels: int
use_tanh: bool
use_layer_norm: bool
@@ -209,8 +214,8 @@ def __call__(self, x):
jnp.array: Output tensor of shape `(N, H*2, W*2, out_channels)`.
"""
x = nn.ConvTranspose(
- self.out_channels, kernel_size=(2, 2), strides=(2, 2), use_bias=False)(
- x)
+ self.out_channels, kernel_size=(2, 2), strides=(2, 2), use_bias=False
+ )(x)
x = _instance_norm2d(x, (1, 2))
if self.use_tanh:
activation_fn = nn.tanh
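
The `broadcast_dims=(1, 2)` argument above preserves the reference behaviour of PyTorch's `nn.Dropout2d`: one mask per channel, shared over H and W. A small standalone illustration in NHWC layout (the function name `channel_dropout` is hypothetical):

```python
import jax
import jax.numpy as jnp


def channel_dropout(x, rng, rate):
  """Drop whole feature maps of an NHWC tensor with probability `rate`."""
  if rate == 0.0:
    return x
  keep_prob = 1.0 - rate
  n, _, _, c = x.shape
  # One Bernoulli draw per (example, channel), broadcast over H and W.
  mask = jax.random.bernoulli(rng, p=keep_prob, shape=(n, 1, 1, c))
  return jnp.where(mask, x / keep_prob, 0.0)


x = jnp.ones((2, 8, 8, 4))
y = channel_dropout(x, jax.random.PRNGKey(0), rate=0.5)
assert y.shape == x.shape
```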
diff --git a/algoperf/workloads/fastmri/fastmri_jax/ssim.py b/algoperf/workloads/fastmri/fastmri_jax/ssim.py
index e15b93616..ca2ee1b60 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/ssim.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/ssim.py
@@ -49,12 +49,9 @@ def ssim(logits, targets, mean=None, std=None, volume_max=None):
return ssims
-def structural_similarity(im1,
- im2,
- data_range=1.0,
- win_size=7,
- k1=0.01,
- k2=0.03):
+def structural_similarity(
+ im1, im2, data_range=1.0, win_size=7, k1=0.01, k2=0.03
+):
"""Compute the mean structural similarity index between two images.
NOTE(dsuo): modified from skimage.metrics.structural_similarity.
@@ -85,7 +82,7 @@ def structural_similarity(im1,
"""
filter_func = functools.partial(_uniform_filter, size=win_size)
- num_points = win_size**len(im1.shape)
+ num_points = win_size ** len(im1.shape)
# filter has already normalized by num_points
cov_norm = num_points / (num_points - 1) # sample covariance
@@ -102,8 +99,8 @@ def structural_similarity(im1,
vy = cov_norm * (uyy - uy * uy)
vxy = cov_norm * (uxy - ux * uy)
- c1 = (k1 * data_range)**2
- c2 = (k2 * data_range)**2
+ c1 = (k1 * data_range) ** 2
+ c2 = (k2 * data_range) ** 2
a1 = 2 * ux * uy + c1
a2 = 2 * vxy + c2
@@ -121,12 +118,15 @@ def structural_similarity(im1,
def _uniform_filter(im, size=7):
-
def conv(im):
- return jnp.convolve(
+ return (
+ jnp.convolve(
jnp.pad(im, pad_width=size // 2, mode='symmetric'),
jnp.ones(size),
- mode='valid') / size
+ mode='valid',
+ )
+ / size
+ )
im = jax.vmap(conv, (0,))(im)
im = jax.vmap(conv, (1,))(im)
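
For readers of the ssim.py hunk: the reformatted constants implement the usual SSIM map, restated below from the local means (`ux`, `uy`), variances (`vx`, `vy`) and covariance (`vxy`) that the code already computes. This is a reference restatement only, not new functionality.

```python
def ssim_map(ux, uy, vx, vy, vxy, data_range=1.0, k1=0.01, k2=0.03):
  """SSIM = ((2*ux*uy + c1) * (2*vxy + c2)) /
  ((ux**2 + uy**2 + c1) * (vx**2 + vy**2 + c2))."""
  c1 = (k1 * data_range) ** 2
  c2 = (k2 * data_range) ** 2
  return ((2 * ux * uy + c1) * (2 * vxy + c2)) / (
    (ux**2 + uy**2 + c1) * (vx**2 + vy**2 + c2)
  )
```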
diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py
index 1156cf30a..08bb25014 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/workload.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py
@@ -4,38 +4,34 @@
import math
from typing import Dict, Optional, Tuple
-from flax import jax_utils
import jax
import jax.numpy as jnp
+from flax import jax_utils
-from algoperf import param_utils
-from algoperf import spec
import algoperf.random_utils as prng
-from algoperf.workloads.fastmri.fastmri_jax.models import UNet
+from algoperf import param_utils, spec
+from algoperf.workloads.fastmri.fastmri_jax.models import DROPOUT_RATE, UNet
from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim
from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload
class FastMRIWorkload(BaseFastMRIWorkload):
-
def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ self,
+ rng: spec.RandomState,
+ ) -> spec.ModelInitState:
"""aux_dropout_rate is unused."""
- del aux_dropout_rate
fake_batch = jnp.zeros((13, 320, 320))
self._model = UNet(
- num_pool_layers=self.num_pool_layers,
- num_channels=self.num_channels,
- use_tanh=self.use_tanh,
- use_layer_norm=self.use_layer_norm,
- dropout_rate=dropout_rate)
- params_rng, dropout_rng = jax.random.split(rng)
- variables = jax.jit(
- self._model.init)({'params': params_rng, 'dropout': dropout_rng},
- fake_batch)
+ num_pool_layers=self.num_pool_layers,
+ num_channels=self.num_channels,
+ use_tanh=self.use_tanh,
+ use_layer_norm=self.use_layer_norm,
+ )
+
+ params_rng, _ = jax.random.split(rng)
+ init_fn = functools.partial(self._model.init, train=False)
+ variables = jax.jit(init_fn)({'params': params_rng}, fake_batch)
params = variables['params']
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
@@ -46,30 +42,37 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Conv_0'
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = DROPOUT_RATE,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
train = mode == spec.ForwardPassMode.TRAIN
- logits = self._model.apply({'params': params},
- augmented_and_preprocessed_input_batch['inputs'],
- rngs={'dropout': rng},
- train=train)
+
+ logits = self._model.apply(
+ {'params': params},
+ augmented_and_preprocessed_input_batch['inputs'],
+ rngs={'dropout': rng},
+ train=train,
+ dropout_rate=dropout_rate,
+ )
return logits, None
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -78,8 +81,9 @@ def loss_fn(
"""
del label_smoothing
per_example_losses = jnp.mean(
- jnp.abs(logits_batch - label_batch),
- axis=tuple(range(1, logits_batch.ndim)))
+ jnp.abs(logits_batch - label_batch),
+ axis=tuple(range(1, logits_batch.ndim)),
+ )
# mask_batch is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -88,56 +92,63 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': n_valid_examples,
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': n_valid_examples,
+ 'per_example': per_example_losses,
}
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, 0, 0, 0),
- static_broadcasted_argnums=(0,))
- def _eval_model(self,
- params: spec.Tensor,
- batch: Dict[str, spec.Tensor],
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, 0, 0, 0),
+ static_broadcasted_argnums=(0,),
+ )
+ def _eval_model(
+ self,
+ params: spec.Tensor,
+ batch: Dict[str, spec.Tensor],
+ rng: spec.RandomState,
+ ) -> Dict[str, spec.Tensor]:
"""Return the SSIM and loss as a dict."""
logits, _ = self.model_fn(
- params,
- batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=rng,
+ update_batch_norm=False,
+ )
targets = batch['targets']
weights = batch.get('weights')
if weights is None:
weights = jnp.ones(len(logits))
ssim_vals = ssim(
- logits,
- targets,
- mean=batch['mean'],
- std=batch['std'],
- volume_max=batch['volume_max'])
+ logits,
+ targets,
+ mean=batch['mean'],
+ std=batch['std'],
+ volume_max=batch['volume_max'],
+ )
ssim_sum = jnp.sum(ssim_vals * weights)
summed_loss = self.loss_fn(targets, logits, weights)['summed']
metrics = {
- 'ssim': ssim_sum,
- 'loss': summed_loss,
+ 'ssim': ssim_sum,
+ 'loss': summed_loss,
}
metrics = jax.lax.psum(metrics, axis_name='batch')
return metrics
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del model_state
del global_step
@@ -146,27 +157,27 @@ def _eval_model_on_split(self,
if split not in self._eval_iters:
# These iterators repeat indefinitely.
self._eval_iters[split] = self._build_input_queue(
- data_rng,
- split,
- data_dir,
- global_batch_size=global_batch_size,
- repeat_final_dataset=True,
- num_batches=num_batches)
-
- total_metrics = {'ssim': 0., 'loss': 0.}
+ data_rng,
+ split,
+ data_dir,
+ global_batch_size=global_batch_size,
+ repeat_final_dataset=True,
+ num_batches=num_batches,
+ )
+
+ total_metrics = {'ssim': 0.0, 'loss': 0.0}
eval_rngs = prng.split(model_rng, jax.local_device_count())
for _ in range(num_batches):
batch = next(self._eval_iters[split])
# We already sum these metrics across devices inside _eval_model.
synced_metrics = self._eval_model(params, batch, eval_rngs)
total_metrics = {
- k: v + synced_metrics[k][0] for k, v in total_metrics.items()
+ k: v + synced_metrics[k][0] for k, v in total_metrics.items()
}
return {k: float(v.item() / num_examples) for k, v in total_metrics.items()}
class FastMRIModelSizeWorkload(FastMRIWorkload):
-
@property
def num_pool_layers(self) -> bool:
"""Whether or not to use tanh activations in the model."""
@@ -187,7 +198,6 @@ def test_target_value(self) -> float:
class FastMRITanhWorkload(FastMRIWorkload):
-
@property
def use_tanh(self) -> bool:
"""Whether or not to use tanh activations in the model."""
@@ -203,7 +213,6 @@ def test_target_value(self) -> float:
class FastMRILayerNormWorkload(FastMRIWorkload):
-
@property
def use_layer_norm(self) -> bool:
"""Whether or not to use tanh activations in the model."""
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models.py b/algoperf/workloads/fastmri/fastmri_pytorch/models.py
index 28f20bf20..16cf8bd54 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/models.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/models.py
@@ -5,93 +5,93 @@
"""
from functools import partial
-from typing import Optional
import torch
-from torch import nn
-from torch import Tensor
+from torch import Tensor, nn
from torch.nn import functional as F
from algoperf import init_utils
+from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout
+
+DROPOUT_RATE = 0.0
class UNet(nn.Module):
r"""U-Net model from
- `"U-net: Convolutional networks
- for biomedical image segmentation"
- `_.
- """
-
- def __init__(self,
- in_chans: int = 1,
- out_chans: int = 1,
- num_channels: int = 32,
- num_pool_layers: int = 4,
- dropout_rate: Optional[float] = 0.0,
- use_tanh: bool = False,
- use_layer_norm: bool = False) -> None:
+ `"U-net: Convolutional networks
+ for biomedical image segmentation"
+ `_.
+ """
+
+ def __init__(
+ self,
+ in_chans: int = 1,
+ out_chans: int = 1,
+ num_channels: int = 32,
+ num_pool_layers: int = 4,
+ use_tanh: bool = False,
+ use_layer_norm: bool = False,
+ ) -> None:
super().__init__()
self.in_chans = in_chans
self.out_chans = out_chans
self.num_channels = num_channels
self.num_pool_layers = num_pool_layers
- if dropout_rate is None:
- dropout_rate = 0.0
- self.down_sample_layers = nn.ModuleList([
- ConvBlock(in_chans,
- num_channels,
- dropout_rate,
- use_tanh,
- use_layer_norm)
- ])
+
+ self.down_sample_layers = nn.ModuleList(
+ [ConvBlock(in_chans, num_channels, use_tanh, use_layer_norm)]
+ )
ch = num_channels
for _ in range(num_pool_layers - 1):
self.down_sample_layers.append(
- ConvBlock(ch, ch * 2, dropout_rate, use_tanh, use_layer_norm))
+ ConvBlock(ch, ch * 2, use_tanh, use_layer_norm)
+ )
ch *= 2
- self.conv = ConvBlock(ch, ch * 2, dropout_rate, use_tanh, use_layer_norm)
+ self.conv = ConvBlock(ch, ch * 2, use_tanh, use_layer_norm)
self.up_conv = nn.ModuleList()
self.up_transpose_conv = nn.ModuleList()
for _ in range(num_pool_layers - 1):
self.up_transpose_conv.append(
- TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
- self.up_conv.append(
- ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm))
+ TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)
+ )
+ self.up_conv.append(ConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
ch //= 2
self.up_transpose_conv.append(
- TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
+ TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)
+ )
self.up_conv.append(
- nn.Sequential(
- ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm),
- nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
- ))
+ SequentialWithDropout(
+ ConvBlock(ch * 2, ch, use_tanh, use_layer_norm),
+ nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
+ )
+ )
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
init_utils.pytorch_default_init(m)
- def forward(self, x: Tensor) -> Tensor:
+ def forward(self, x: Tensor, dropout_rate: float = DROPOUT_RATE) -> Tensor:
stack = []
output = x
# apply down-sampling layers
for layer in self.down_sample_layers:
- output = layer(output)
+ output = layer(output, dropout_rate)
stack.append(output)
output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)
- output = self.conv(output)
+ output = self.conv(output, dropout_rate)
# apply up-sampling layers
for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
downsample_layer = stack.pop()
output = transpose_conv(output)
- # reflect pad on the right/botton if needed to handle
+ # reflect pad on the right/bottom if needed to handle
# odd input dimensions
padding = [0, 0, 0, 0]
if output.shape[-1] != downsample_layer.shape[-1]:
@@ -99,10 +99,10 @@ def forward(self, x: Tensor) -> Tensor:
if output.shape[-2] != downsample_layer.shape[-2]:
padding[3] = 1 # padding bottom
if torch.sum(torch.tensor(padding)) != 0:
- output = F.pad(output, padding, "reflect")
+ output = F.pad(output, padding, 'reflect')
output = torch.cat([output, downsample_layer], dim=1)
- output = conv(output)
+ output = conv(output, dropout_rate)
return output
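The hunk above threads `dropout_rate` through `forward` at call time instead of fixing it when the modules are built. The helpers it imports, `CustomDropout2d` and `SequentialWithDropout` from `algoperf.pytorch_utils`, are not part of this diff; the following is only a plausible sketch of what such helpers need to do (an assumption based on the `_supports_custom_dropout` flag set in `ConvBlock` below, not the repository's actual implementation):

```python
import torch.nn.functional as F
from torch import Tensor, nn


class CustomDropout2d(nn.Module):
  """Dropout2d whose rate is supplied at call time, not at construction."""

  def forward(self, x: Tensor, p: float) -> Tensor:
    return F.dropout2d(x, p=p, training=self.training)


class SequentialWithDropout(nn.Sequential):
  """nn.Sequential that forwards the dropout rate to modules that accept it."""

  def forward(self, x: Tensor, dropout_rate: float) -> Tensor:
    for module in self:
      if isinstance(module, CustomDropout2d) or getattr(
          module, '_supports_custom_dropout', False):
        x = module(x, dropout_rate)
      else:
        x = module(x)
    return x
```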
@@ -111,13 +111,11 @@ class ConvBlock(nn.Module):
# A Convolutional Block that consists of two convolution layers each
  # followed by instance normalization, LeakyReLU activation, and dropout.
- def __init__(self,
- in_chans: int,
- out_chans: int,
- dropout_rate: float,
- use_tanh: bool,
- use_layer_norm: bool) -> None:
+ def __init__(
+ self, in_chans: int, out_chans: int, use_tanh: bool, use_layer_norm: bool
+ ) -> None:
super().__init__()
+ self._supports_custom_dropout = True
if use_layer_norm:
norm_layer = partial(nn.GroupNorm, 1, eps=1e-6)
@@ -127,19 +125,19 @@ def __init__(self,
activation_fn = nn.Tanh()
else:
activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True)
- self.conv_layers = nn.Sequential(
- nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
- norm_layer(out_chans),
- activation_fn,
- nn.Dropout2d(dropout_rate),
- nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
- norm_layer(out_chans),
- activation_fn,
- nn.Dropout2d(dropout_rate),
+ self.conv_layers = SequentialWithDropout(
+ nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
+ norm_layer(out_chans),
+ activation_fn,
+ CustomDropout2d(),
+ nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
+ norm_layer(out_chans),
+ activation_fn,
+ CustomDropout2d(),
)
- def forward(self, x: Tensor) -> Tensor:
- return self.conv_layers(x)
+ def forward(self, x: Tensor, dropout_rate: float) -> Tensor:
+ return self.conv_layers(x, dropout_rate)
class TransposeConvBlock(nn.Module):
@@ -147,11 +145,11 @@ class TransposeConvBlock(nn.Module):
# layers followed by instance normalization and LeakyReLU activation.
def __init__(
- self,
- in_chans: int,
- out_chans: int,
- use_tanh: bool,
- use_layer_norm: bool,
+ self,
+ in_chans: int,
+ out_chans: int,
+ use_tanh: bool,
+ use_layer_norm: bool,
):
super().__init__()
if use_tanh:
@@ -159,10 +157,11 @@ def __init__(
else:
activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.layers = nn.Sequential(
- nn.ConvTranspose2d(
- in_chans, out_chans, kernel_size=2, stride=2, bias=False),
- nn.InstanceNorm2d(out_chans),
- activation_fn,
+ nn.ConvTranspose2d(
+ in_chans, out_chans, kernel_size=2, stride=2, bias=False
+ ),
+ nn.InstanceNorm2d(out_chans),
+ activation_fn,
)
def forward(self, x: Tensor) -> Tensor:
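With `models.py` rewritten this way, dropout disappears from the constructor entirely: the U-Net is built once and the rate is chosen per forward pass, defaulting to the module-level `DROPOUT_RATE = 0.0`. A short usage sketch (the `0.1` rate and the toy batch are arbitrary example values):

```python
import torch

from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet

model = UNet(num_pool_layers=4, num_channels=32)
x = torch.randn(2, 1, 320, 320)  # NCHW singlecoil slices

logits_train = model(x, dropout_rate=0.1)  # stochastic in training mode
model.eval()
logits_eval = model(x)  # falls back to DROPOUT_RATE = 0.0
print(logits_train.shape, logits_eval.shape)  # both torch.Size([2, 1, 320, 320])
```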
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py b/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py
index 45b61bea4..7d594b959 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py
@@ -32,9 +32,9 @@ def ssim(logits, targets, mean=None, std=None, volume_max=None):
# NOTE(dsuo): `volume_max` can be 0 if we have a padded batch, but this will
# lead to NaN values in `ssim`.
- volume_max = torch.where(volume_max == 0,
- torch.ones_like(volume_max),
- volume_max)
+ volume_max = torch.where(
+ volume_max == 0, torch.ones_like(volume_max), volume_max
+ )
if mean is None:
mean = torch.zeros(logits.shape[0], device=DEVICE)
@@ -56,12 +56,9 @@ def ssim(logits, targets, mean=None, std=None, volume_max=None):
return ssims
-def structural_similarity(im1,
- im2,
- data_range=1.0,
- win_size=7,
- k1=0.01,
- k2=0.03):
+def structural_similarity(
+ im1, im2, data_range=1.0, win_size=7, k1=0.01, k2=0.03
+):
"""Compute the mean structural similarity index between two images.
NOTE(dsuo): modified from skimage.metrics.structural_similarity.
@@ -92,7 +89,7 @@ def structural_similarity(im1,
"""
filter_func = functools.partial(_uniform_filter, size=win_size)
- num_points = win_size**len(im1.shape)
+ num_points = win_size ** len(im1.shape)
# filter has already normalized by num_points
cov_norm = num_points / (num_points - 1) # sample covariance
@@ -109,8 +106,8 @@ def structural_similarity(im1,
vy = cov_norm * (uyy - uy * uy)
vxy = cov_norm * (uxy - ux * uy)
- c1 = (k1 * data_range)**2
- c2 = (k2 * data_range)**2
+ c1 = (k1 * data_range) ** 2
+ c2 = (k2 * data_range) ** 2
a1 = 2 * ux * uy + c1
a2 = 2 * vxy + c2
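The reformatted lines above are the standard SSIM constants and cross terms; the remainder of `structural_similarity` falls outside this hunk, so the sketch below completes the expression from the usual skimage-style definition (an assumption about the unshown lines, not a copy of them). It works elementwise on floats, NumPy arrays, or tensors:

```python
def ssim_map(ux, uy, vx, vy, vxy, data_range=1.0, k1=0.01, k2=0.03):
  # Same constants as in the hunk above.
  c1 = (k1 * data_range) ** 2
  c2 = (k2 * data_range) ** 2
  a1 = 2 * ux * uy + c1    # luminance numerator
  a2 = 2 * vxy + c2        # contrast/structure numerator
  b1 = ux**2 + uy**2 + c1  # luminance denominator
  b2 = vx + vy + c2        # contrast/structure denominator
  return (a1 * a2) / (b1 * b2)


# Identical local statistics give a perfect score of 1.0.
print(ssim_map(ux=0.5, uy=0.5, vx=0.04, vy=0.04, vxy=0.04))  # 1.0
```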
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
index 58943de2f..bddf6b1f3 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
@@ -9,10 +9,9 @@
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
-from algoperf import param_utils
-from algoperf import pytorch_utils
-from algoperf import spec
import algoperf.random_utils as prng
+from algoperf import param_utils, pytorch_utils, spec
+from algoperf.workloads.fastmri.fastmri_pytorch import models
from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet
from algoperf.workloads.fastmri.fastmri_pytorch.ssim import ssim
from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload
@@ -21,28 +20,31 @@
class FastMRIWorkload(BaseFastMRIWorkload):
-
- def _build_input_queue(self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None):
+ def _build_input_queue(
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ):
per_device_batch_size = int(global_batch_size / N_GPUS)
# Only create and iterate over tf input pipeline in one Python process to
# avoid creating too many threads.
if RANK == 0:
data_rng = data_rng.astype('uint32')
- np_iter = super()._build_input_queue(data_rng,
- split,
- data_dir,
- global_batch_size,
- cache,
- repeat_final_dataset,
- num_batches)
+ np_iter = super()._build_input_queue(
+ data_rng,
+ split,
+ data_dir,
+ global_batch_size,
+ cache,
+ repeat_final_dataset,
+ num_batches,
+ )
while True:
if RANK == 0:
@@ -58,20 +60,23 @@ def _build_input_queue(self,
else:
aux_tensor_list.append(tensor)
batch[key] = (
- tensor[0] if USE_PYTORCH_DDP else tensor.view(
- -1, *value.shape[2:]))
+ tensor[0] if USE_PYTORCH_DDP else tensor.view(-1, *value.shape[2:])
+ )
# Send batch to other devices when using DDP.
if USE_PYTORCH_DDP:
if split != 'train':
# During eval, the batch size of the remainder might be different.
per_device_batch_size = torch.tensor(
- len(batch['inputs']), dtype=torch.int32, device=DEVICE)
+ len(batch['inputs']), dtype=torch.int32, device=DEVICE
+ )
dist.broadcast(per_device_batch_size, src=0)
weights = weights if 'weights' in batch else None
if weights is None:
- weights = torch.ones((N_GPUS, per_device_batch_size),
- dtype=torch.float64,
- device=DEVICE)
+ weights = torch.ones(
+ (N_GPUS, per_device_batch_size),
+ dtype=torch.float64,
+ device=DEVICE,
+ )
# Has no effect, but without it `batch` has no `weights` key
# for RANK == 0, but has one for all others.
batch['weights'] = weights[0]
@@ -82,20 +87,22 @@ def _build_input_queue(self,
batch = {}
if split != 'train':
# During eval, the batch size of the remainder might be different.
- per_device_batch_size = torch.empty((),
- dtype=torch.int32,
- device=DEVICE)
+ per_device_batch_size = torch.empty(
+ (), dtype=torch.int32, device=DEVICE
+ )
dist.broadcast(per_device_batch_size, src=0)
- weights = torch.empty((N_GPUS, per_device_batch_size),
- dtype=torch.float64,
- device=DEVICE)
+ weights = torch.empty(
+ (N_GPUS, per_device_batch_size), dtype=torch.float64, device=DEVICE
+ )
dist.broadcast(weights, src=0)
batch['weights'] = weights[RANK]
- tensors = torch.empty((2, N_GPUS, per_device_batch_size, 320, 320),
- device=DEVICE)
+ tensors = torch.empty(
+ (2, N_GPUS, per_device_batch_size, 320, 320), device=DEVICE
+ )
dist.broadcast(tensors, src=0)
- aux_tensors = torch.empty((3, N_GPUS, per_device_batch_size),
- device=DEVICE)
+ aux_tensors = torch.empty(
+ (3, N_GPUS, per_device_batch_size), device=DEVICE
+ )
dist.broadcast(aux_tensors, src=0)
# Note that the batch dict in the RANK == 0 process is ordered.
batch['inputs'] = tensors[0][RANK]
@@ -105,19 +112,14 @@ def _build_input_queue(self,
batch['volume_max'] = aux_tensors[2][RANK]
yield batch
- def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- del aux_dropout_rate
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
model = UNet(
- num_pool_layers=self.num_pool_layers,
- num_channels=self.num_channels,
- use_tanh=self.use_tanh,
- use_layer_norm=self.use_layer_norm,
- dropout_rate=dropout_rate)
+ num_pool_layers=self.num_pool_layers,
+ num_channels=self.num_channels,
+ use_tanh=self.use_tanh,
+ use_layer_norm=self.use_layer_norm,
+ )
self._param_shapes = param_utils.pytorch_param_shapes(model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
model.to(DEVICE)
@@ -132,13 +134,15 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['up_conv.3.1.weight', 'up_conv.3.1.bias']
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
@@ -152,25 +156,27 @@ def model_fn(
model.train()
contexts = {
- spec.ForwardPassMode.EVAL: torch.no_grad,
- spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+ spec.ForwardPassMode.EVAL: torch.no_grad,
+ spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
logit_batch = model(
- augmented_and_preprocessed_input_batch['inputs'].unsqueeze(
- 1)).squeeze(1)
+ augmented_and_preprocessed_input_batch['inputs'].unsqueeze(1),
+ dropout_rate=dropout_rate,
+ ).squeeze(1)
return logit_batch, None
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -179,7 +185,8 @@ def loss_fn(
"""
del label_smoothing
per_example_losses = F.l1_loss(
- logits_batch, label_batch, reduction='none').mean(dim=(1, 2))
+ logits_batch, label_batch, reduction='none'
+ ).mean(dim=(1, 2))
# mask_batch is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -188,46 +195,52 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
+ 'per_example': per_example_losses,
}
- def _eval_model(self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+ def _eval_model(
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ rng: spec.RandomState,
+ ) -> Dict[str, spec.Tensor]:
"""Return the SSIM and loss as a dict."""
outputs, _ = self.model_fn(
- params,
- batch,
- None,
- spec.ForwardPassMode.EVAL,
- rng,
- update_batch_norm=False)
+ params,
+ batch,
+ None,
+ spec.ForwardPassMode.EVAL,
+ rng,
+ update_batch_norm=False,
+ )
targets = batch['targets']
weights = batch.get('weights')
if weights is None:
weights = torch.ones(len(outputs), device=DEVICE)
weights_sum = weights.sum().to(torch.int)
ssim_sum = ssim(
- outputs[:weights_sum],
- targets[:weights_sum],
- mean=batch['mean'][:weights_sum],
- std=batch['std'][:weights_sum],
- volume_max=batch['volume_max'][:weights_sum]).sum()
+ outputs[:weights_sum],
+ targets[:weights_sum],
+ mean=batch['mean'][:weights_sum],
+ std=batch['std'][:weights_sum],
+ volume_max=batch['volume_max'][:weights_sum],
+ ).sum()
summed_loss = self.loss_fn(targets, outputs, weights)['summed']
return {'ssim': ssim_sum, 'loss': summed_loss}
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del model_state
del global_step
@@ -236,22 +249,23 @@ def _eval_model_on_split(self,
if split not in self._eval_iters:
# These iterators repeat indefinitely.
self._eval_iters[split] = self._build_input_queue(
- data_rng,
- split,
- data_dir,
- global_batch_size=global_batch_size,
- repeat_final_dataset=True,
- num_batches=num_batches)
+ data_rng,
+ split,
+ data_dir,
+ global_batch_size=global_batch_size,
+ repeat_final_dataset=True,
+ num_batches=num_batches,
+ )
total_metrics = {
- 'ssim': torch.tensor(0., device=DEVICE),
- 'loss': torch.tensor(0., device=DEVICE),
+ 'ssim': torch.tensor(0.0, device=DEVICE),
+ 'loss': torch.tensor(0.0, device=DEVICE),
}
for _ in range(num_batches):
batch = next(self._eval_iters[split])
batch_metrics = self._eval_model(params, batch, model_rng)
total_metrics = {
- k: v + batch_metrics[k] for k, v in total_metrics.items()
+ k: v + batch_metrics[k] for k, v in total_metrics.items()
}
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
@@ -260,7 +274,6 @@ def _eval_model_on_split(self,
class FastMRIModelSizeWorkload(FastMRIWorkload):
-
@property
  def num_pool_layers(self) -> int:
    """Number of pooling layers used in the U-Net model."""
@@ -281,7 +294,6 @@ def test_target_value(self) -> float:
class FastMRITanhWorkload(FastMRIWorkload):
-
@property
def use_tanh(self) -> bool:
"""Whether or not to use tanh activations in the model."""
@@ -297,7 +309,6 @@ def test_target_value(self) -> float:
class FastMRILayerNormWorkload(FastMRIWorkload):
-
@property
def use_layer_norm(self) -> bool:
"""Whether or not to use tanh activations in the model."""
diff --git a/algoperf/workloads/fastmri/input_pipeline.py b/algoperf/workloads/fastmri/input_pipeline.py
index f20611f43..62b3219c5 100644
--- a/algoperf/workloads/fastmri/input_pipeline.py
+++ b/algoperf/workloads/fastmri/input_pipeline.py
@@ -16,12 +16,9 @@
_EVAL_SEED = 0
-def _process_example(kspace,
- kspace_shape,
- target,
- target_shape,
- volume_max,
- seed):
+def _process_example(
+ kspace, kspace_shape, target, target_shape, volume_max, seed
+):
"""Generate a single example (slice from mri image).
Args:
@@ -45,15 +42,17 @@ def _process_example(kspace,
acceleration = tf.convert_to_tensor(4.0, dtype=tf.float32)
num_low_frequencies = tf.cast(
- num_cols_float * center_fraction, dtype=tf.int32)
+ num_cols_float * center_fraction, dtype=tf.int32
+ )
# calculate_center_mask
mask = tf.zeros(num_cols, dtype=tf.float32)
pad = (num_cols - num_low_frequencies + 1) // 2
mask = tf.tensor_scatter_nd_update(
- mask,
- tf.reshape(tf.range(pad, pad + num_low_frequencies), (-1, 1)),
- tf.ones(num_low_frequencies))
+ mask,
+ tf.reshape(tf.range(pad, pad + num_low_frequencies), (-1, 1)),
+ tf.ones(num_low_frequencies),
+ )
# reshape_mask
center_mask = tf.reshape(mask, (1, num_cols))
@@ -61,10 +60,12 @@ def _process_example(kspace,
# calculate_acceleration_mask
num_low_frequencies_float = tf.cast(num_low_frequencies, dtype=tf.float32)
prob = (num_cols_float / acceleration - num_low_frequencies_float) / (
- num_cols_float - num_low_frequencies_float)
+ num_cols_float - num_low_frequencies_float
+ )
mask = tf.cast(
- tf.random.stateless_uniform((num_cols,), seed) < prob, dtype=tf.float32)
+ tf.random.stateless_uniform((num_cols,), seed) < prob, dtype=tf.float32
+ )
acceleration_mask = tf.reshape(mask, (1, num_cols))
mask = tf.math.maximum(center_mask, acceleration_mask)
@@ -78,9 +79,11 @@ def _process_example(kspace,
shifted_image = tf.signal.ifft2d(shifted_kspace)
image = tf.signal.fftshift(shifted_image, axes=(0, 1))
scaling_norm = tf.cast(
- tf.math.sqrt(
- tf.cast(tf.math.reduce_prod(tf.shape(kspace)[-2:]), 'float32')),
- kspace.dtype)
+ tf.math.sqrt(
+ tf.cast(tf.math.reduce_prod(tf.shape(kspace)[-2:]), 'float32')
+ ),
+ kspace.dtype,
+ )
image = image * scaling_norm
image = tf.stack((tf.math.real(image), tf.math.imag(image)), axis=-1)
@@ -108,48 +111,58 @@ def _process_example(kspace,
target = tf.clip_by_value(norm_target, -6, 6)
return {
- 'inputs': image,
- 'targets': target,
- 'mean': mean,
- 'std': std,
- 'volume_max': volume_max,
+ 'inputs': image,
+ 'targets': target,
+ 'mean': mean,
+ 'std': std,
+ 'volume_max': volume_max,
}
def _h5_to_examples(path, log=False):
"""Yield MRI slices from an hdf5 file containing a single MRI volume."""
if log:
- tf.print('fastmri_dataset._h5_to_examples call:',
- path,
- datetime.datetime.now().strftime('%H:%M:%S:%f'))
+ tf.print(
+ 'fastmri_dataset._h5_to_examples call:',
+ path,
+ datetime.datetime.now().strftime('%H:%M:%S:%f'),
+ )
with open(path, 'rb') as gf:
with h5py.File(gf, 'r') as hf:
# NOTE(dsuo): logic taken from reference code
volume_max = hf.attrs.get('max', 0.0)
for i in range(hf['kspace'].shape[0]):
- yield hf['kspace'][i], hf['kspace'][i].shape, hf['reconstruction_esc'][
- i], hf['reconstruction_esc'][i].shape, volume_max
+ yield (
+ hf['kspace'][i],
+ hf['kspace'][i].shape,
+ hf['reconstruction_esc'][i],
+ hf['reconstruction_esc'][i].shape,
+ volume_max,
+ )
def _create_generator(filename):
signature = (
- tf.TensorSpec(shape=(640, None), dtype=tf.complex64),
- tf.TensorSpec(shape=(2,), dtype=tf.int32),
- tf.TensorSpec(shape=(320, 320), dtype=tf.float32),
- tf.TensorSpec(shape=(2,), dtype=tf.int32),
- tf.TensorSpec(shape=(), dtype=tf.float32),
+ tf.TensorSpec(shape=(640, None), dtype=tf.complex64),
+ tf.TensorSpec(shape=(2,), dtype=tf.int32),
+ tf.TensorSpec(shape=(320, 320), dtype=tf.float32),
+ tf.TensorSpec(shape=(2,), dtype=tf.int32),
+ tf.TensorSpec(shape=(), dtype=tf.float32),
)
return tf.data.Dataset.from_generator(
- _h5_to_examples, args=(filename,), output_signature=signature)
+ _h5_to_examples, args=(filename,), output_signature=signature
+ )
-def load_fastmri_split(global_batch_size,
- split,
- data_dir,
- shuffle_rng,
- num_batches,
- repeat_final_eval_dataset):
+def load_fastmri_split(
+ global_batch_size,
+ split,
+ data_dir,
+ shuffle_rng,
+ num_batches,
+ repeat_final_eval_dataset,
+):
"""Creates a split from the FastMRI dataset using tf.data.
NOTE: only creates knee singlecoil datasets.
@@ -169,11 +182,13 @@ def load_fastmri_split(global_batch_size,
# Check if data directories exist because glob will not raise an error
if not os.path.exists(os.path.join(data_dir, _TRAIN_DIR)):
- raise NotADirectoryError('Directory not found: {}'.format(
- os.path.join(data_dir, _TRAIN_DIR)))
+ raise NotADirectoryError(
+ 'Directory not found: {}'.format(os.path.join(data_dir, _TRAIN_DIR))
+ )
if not os.path.exists(os.path.join(data_dir, _VAL_DIR)):
- raise NotADirectoryError('Directory not found: {}'.format(
- os.path.join(data_dir, _VAL_DIR)))
+ raise NotADirectoryError(
+ 'Directory not found: {}'.format(os.path.join(data_dir, _VAL_DIR))
+ )
if split in ['train', 'eval_train']:
file_pattern = os.path.join(data_dir, _TRAIN_DIR, '*.h5')
@@ -190,10 +205,8 @@ def load_fastmri_split(global_batch_size,
shuffle = is_train or split == 'eval_train'
ds = tf.data.Dataset.from_tensor_slices(h5_paths)
ds = ds.interleave(
- _create_generator,
- cycle_length=32,
- block_length=64,
- num_parallel_calls=16)
+ _create_generator, cycle_length=32, block_length=64, num_parallel_calls=16
+ )
if is_train:
ds = ds.cache()
@@ -201,7 +214,8 @@ def process_example(example_index, example):
if shuffle:
process_rng = tf.cast(jax.random.fold_in(shuffle_rng, 0), tf.int64)
process_rng = tf.random.experimental.stateless_fold_in(
- process_rng, example_index)
+ process_rng, example_index
+ )
else:
# NOTE(dsuo): we use fixed randomness for eval.
process_rng = tf.cast(jax.random.PRNGKey(_EVAL_SEED), tf.int64)
@@ -211,9 +225,8 @@ def process_example(example_index, example):
if shuffle:
ds = ds.shuffle(
- 16 * global_batch_size,
- seed=shuffle_rng[0],
- reshuffle_each_iteration=True)
+ 16 * global_batch_size, seed=shuffle_rng[0], reshuffle_each_iteration=True
+ )
if is_train:
ds = ds.repeat()
@@ -231,7 +244,8 @@ def process_example(example_index, example):
ds = ds.repeat()
ds = ds.prefetch(10)
return map(
- functools.partial(
- data_utils.shard_and_maybe_pad_np,
- global_batch_size=global_batch_size),
- ds)
+ functools.partial(
+ data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size
+ ),
+ ds,
+ )
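The mask arithmetic reflowed above is the usual fastMRI 4x undersampling: a fully sampled low-frequency band covering `center_fraction` of the k-space columns, plus Bernoulli-sampled columns chosen so the expected overall rate is `1 / acceleration`. A NumPy sketch of the same computation; the `center_fraction=0.08` default is the conventional value for 4x acceleration and is set outside this hunk, so treat it as an assumption:

```python
import numpy as np


def undersampling_mask(num_cols, center_fraction=0.08, acceleration=4.0, seed=0):
  rng = np.random.default_rng(seed)
  num_low = int(num_cols * center_fraction)
  # Fully sampled low-frequency band in the middle of k-space.
  center = np.zeros(num_cols, dtype=np.float32)
  pad = (num_cols - num_low + 1) // 2
  center[pad:pad + num_low] = 1.0
  # Random columns so the expected total rate is 1 / acceleration.
  prob = (num_cols / acceleration - num_low) / (num_cols - num_low)
  random = (rng.uniform(size=num_cols) < prob).astype(np.float32)
  return np.maximum(center, random)


print(undersampling_mask(372).mean())  # roughly 0.25
```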
diff --git a/algoperf/workloads/fastmri/workload.py b/algoperf/workloads/fastmri/workload.py
index 051749cc3..0b1ecfaa1 100644
--- a/algoperf/workloads/fastmri/workload.py
+++ b/algoperf/workloads/fastmri/workload.py
@@ -8,7 +8,6 @@
class BaseFastMRIWorkload(spec.Workload):
-
@property
def target_metric_name(self) -> str:
"""The name of the target metric (useful for scoring/processing code)."""
@@ -61,8 +60,9 @@ def num_eval_train_examples(self) -> int:
# Round up from num_validation_examples (which is the default for
# num_eval_train_examples) to the next multiple of eval_batch_size, so that
# we don't have to extract the correctly sized subset of the training data.
- rounded_up_multiple = math.ceil(self.num_validation_examples /
- self.eval_batch_size)
+ rounded_up_multiple = math.ceil(
+ self.num_validation_examples / self.eval_batch_size
+ )
return rounded_up_multiple * self.eval_batch_size
@property
@@ -106,18 +106,22 @@ def step_hint(self) -> int:
"""Approx. steps the baseline can do in the allowed runtime budget."""
return 18_094
- def _build_input_queue(self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None):
+ def _build_input_queue(
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ):
del cache
- return input_pipeline.load_fastmri_split(global_batch_size,
- split,
- data_dir,
- data_rng,
- num_batches,
- repeat_final_dataset)
+ return input_pipeline.load_fastmri_split(
+ global_batch_size,
+ split,
+ data_dir,
+ data_rng,
+ num_batches,
+ repeat_final_dataset,
+ )
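The `num_eval_train_examples` change above is purely cosmetic, but the intent of the expression is easy to miss: the eval-train subset is rounded up to a whole number of eval batches so no partial batch has to be carved out of the training data. A one-line illustration with illustrative numbers (the real `num_validation_examples` and `eval_batch_size` are class properties defined elsewhere):

```python
import math

num_validation_examples = 3_554  # illustrative value
eval_batch_size = 256            # illustrative value

rounded_up_multiple = math.ceil(num_validation_examples / eval_batch_size)
print(rounded_up_multiple * eval_batch_size)  # 3584, i.e. 14 full eval batches
```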
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py
index 3d6939218..53368b384 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py
@@ -1,5 +1,5 @@
"""
-Note:
+Note:
The following code is adapted from:
https://github.com/tensorflow/addons/tree/master/tensorflow_addons/image
@@ -12,35 +12,39 @@
import tensorflow as tf
_IMAGE_DTYPES = {
- tf.dtypes.uint8,
- tf.dtypes.int32,
- tf.dtypes.int64,
- tf.dtypes.float16,
- tf.dtypes.float32,
- tf.dtypes.float64,
+ tf.dtypes.uint8,
+ tf.dtypes.int32,
+ tf.dtypes.int64,
+ tf.dtypes.float16,
+ tf.dtypes.float32,
+ tf.dtypes.float64,
}
-Number = Union[float,
- int,
- np.float16,
- np.float32,
- np.float64,
- np.int8,
- np.int16,
- np.int32,
- np.int64,
- np.uint8,
- np.uint16,
- np.uint32,
- np.uint64,]
-
-TensorLike = Union[List[Union[Number, list]],
- tuple,
- Number,
- np.ndarray,
- tf.Tensor,
- tf.SparseTensor,
- tf.Variable,]
+Number = Union[
+ float,
+ int,
+ np.float16,
+ np.float32,
+ np.float64,
+ np.int8,
+ np.int16,
+ np.int32,
+ np.int64,
+ np.uint8,
+ np.uint16,
+ np.uint32,
+ np.uint64,
+]
+
+TensorLike = Union[
+ List[Union[Number, list]],
+ tuple,
+ Number,
+ np.ndarray,
+ tf.Tensor,
+ tf.SparseTensor,
+ tf.Variable,
+]
def get_ndims(image):
@@ -50,16 +54,19 @@ def get_ndims(image):
def to_4d_image(image):
"""Convert 2/3/4D image to 4D image.
- Args:
- image: 2/3/4D `Tensor`.
+ Args:
+ image: 2/3/4D `Tensor`.
- Returns:
- 4D `Tensor` with the same type.
- """
- with tf.control_dependencies([
+ Returns:
+ 4D `Tensor` with the same type.
+ """
+ with tf.control_dependencies(
+ [
tf.debugging.assert_rank_in(
- image, [2, 3, 4], message="`image` must be 2/3/4D tensor")
- ]):
+ image, [2, 3, 4], message='`image` must be 2/3/4D tensor'
+ )
+ ]
+ ):
ndims = image.get_shape().ndims
if ndims is None:
return _dynamic_to_4d_image(image)
@@ -80,12 +87,12 @@ def _dynamic_to_4d_image(image):
left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32)
right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32)
new_shape = tf.concat(
- [
- tf.ones(shape=left_pad, dtype=tf.int32),
- shape,
- tf.ones(shape=right_pad, dtype=tf.int32),
- ],
- axis=0,
+ [
+ tf.ones(shape=left_pad, dtype=tf.int32),
+ shape,
+ tf.ones(shape=right_pad, dtype=tf.int32),
+ ],
+ axis=0,
)
return tf.reshape(image, new_shape)
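`_dynamic_to_4d_image`, reformatted above, normalizes 2D (HW) and 3D (HWC) inputs to the 4D NHWC layout the transform ops expect by prepending a batch axis and, for 2D inputs, appending a channel axis; `from_4d_image` later undoes this. A static-shape equivalent for intuition only (the real helper works on ranks that are unknown until graph execution):

```python
import tensorflow as tf


def to_4d_static(image: tf.Tensor) -> tf.Tensor:
  if image.shape.ndims == 2:  # HW -> 1HW1
    return image[tf.newaxis, :, :, tf.newaxis]
  if image.shape.ndims == 3:  # HWC -> 1HWC
    return image[tf.newaxis, ...]
  return image  # already NHWC


print(to_4d_static(tf.zeros([32, 32])).shape)     # (1, 32, 32, 1)
print(to_4d_static(tf.zeros([32, 32, 3])).shape)  # (1, 32, 32, 3)
```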
@@ -93,16 +100,16 @@ def _dynamic_to_4d_image(image):
def from_4d_image(image, ndims):
"""Convert back to an image with `ndims` rank.
- Args:
- image: 4D `Tensor`.
- ndims: The original rank of the image.
+ Args:
+ image: 4D `Tensor`.
+ ndims: The original rank of the image.
- Returns:
- `ndims`-D `Tensor` with the same type.
- """
+ Returns:
+ `ndims`-D `Tensor` with the same type.
+ """
with tf.control_dependencies(
- [tf.debugging.assert_rank(image, 4,
- message="`image` must be 4D tensor")]):
+ [tf.debugging.assert_rank(image, 4, message='`image` must be 4D tensor')]
+ ):
if isinstance(ndims, tf.Tensor):
return _dynamic_from_4d_image(image, ndims)
elif ndims == 2:
@@ -125,63 +132,64 @@ def _dynamic_from_4d_image(image, original_rank):
def transform(
- images: TensorLike,
- transforms: TensorLike,
- interpolation: str = "nearest",
- fill_mode: str = "constant",
- output_shape: Optional[list] = None,
- name: Optional[str] = None,
- fill_value: TensorLike = 0.0,
+ images: TensorLike,
+ transforms: TensorLike,
+ interpolation: str = 'nearest',
+ fill_mode: str = 'constant',
+ output_shape: Optional[list] = None,
+ name: Optional[str] = None,
+ fill_value: TensorLike = 0.0,
) -> tf.Tensor:
"""Applies the given transform(s) to the image(s).
- Args:
- images: A tensor of shape (num_images, num_rows, num_columns,
- num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or
- (num_rows, num_columns) (HW).
- transforms: Projective transform matrix/matrices. A vector of length 8 or
- tensor of size N x 8. If one row of transforms is
- [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point
- `(x, y)` to a transformed *input* point
- `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
- where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
- the transform mapping input points to output points. Note that
- gradients are not backpropagated into transformation parameters.
- interpolation: Interpolation mode.
- Supported values: "nearest", "bilinear".
- fill_mode: Points outside the boundaries of the input are filled according
- to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
- - *reflect*: `(d c b a | a b c d | d c b a)`
- The input is extended by reflecting about the edge of the last pixel.
- - *constant*: `(k k k k | a b c d | k k k k)`
- The input is extended by filling all values beyond the edge with the
- same constant value k = 0.
- - *wrap*: `(a b c d | a b c d | a b c d)`
- The input is extended by wrapping around to the opposite edge.
- - *nearest*: `(a a a a | a b c d | d d d d)`
- The input is extended by the nearest pixel.
- fill_value: a float represents the value to be filled outside the
- boundaries when `fill_mode` is "constant".
- output_shape: Output dimesion after the transform, [height, width].
- If None, output is the same size as input image.
-
- name: The name of the op.
-
- Returns:
- Image(s) with the same type and shape as `images`, with the given
- transform(s) applied. Transformed coordinates outside of the input image
- will be filled with zeros.
-
- Raises:
- TypeError: If `image` is an invalid type.
- ValueError: If output shape is not 1-D int32 Tensor.
- """
- with tf.name_scope(name or "transform"):
- image_or_images = tf.convert_to_tensor(images, name="images")
+ Args:
+ images: A tensor of shape (num_images, num_rows, num_columns,
+ num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or
+ (num_rows, num_columns) (HW).
+ transforms: Projective transform matrix/matrices. A vector of length 8 or
+ tensor of size N x 8. If one row of transforms is
+ [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point
+ `(x, y)` to a transformed *input* point
+ `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
+ where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
+ the transform mapping input points to output points. Note that
+ gradients are not backpropagated into transformation parameters.
+ interpolation: Interpolation mode.
+ Supported values: "nearest", "bilinear".
+ fill_mode: Points outside the boundaries of the input are filled according
+ to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
+ - *reflect*: `(d c b a | a b c d | d c b a)`
+ The input is extended by reflecting about the edge of the last pixel.
+ - *constant*: `(k k k k | a b c d | k k k k)`
+ The input is extended by filling all values beyond the edge with the
+ same constant value k = 0.
+ - *wrap*: `(a b c d | a b c d | a b c d)`
+ The input is extended by wrapping around to the opposite edge.
+ - *nearest*: `(a a a a | a b c d | d d d d)`
+ The input is extended by the nearest pixel.
+    fill_value: a float representing the value to be filled outside the
+ boundaries when `fill_mode` is "constant".
+      output_shape: Output dimension after the transform, [height, width].
+ If None, output is the same size as input image.
+
+ name: The name of the op.
+
+ Returns:
+ Image(s) with the same type and shape as `images`, with the given
+ transform(s) applied. Transformed coordinates outside of the input image
+ will be filled with zeros.
+
+ Raises:
+ TypeError: If `image` is an invalid type.
+ ValueError: If output shape is not 1-D int32 Tensor.
+ """
+ with tf.name_scope(name or 'transform'):
+ image_or_images = tf.convert_to_tensor(images, name='images')
transform_or_transforms = tf.convert_to_tensor(
- transforms, name="transforms", dtype=tf.dtypes.float32)
+ transforms, name='transforms', dtype=tf.dtypes.float32
+ )
if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
- raise TypeError("Invalid dtype %s." % image_or_images.dtype)
+ raise TypeError('Invalid dtype %s.' % image_or_images.dtype)
images = to_4d_image(image_or_images)
original_ndims = get_ndims(image_or_images)
@@ -189,61 +197,67 @@ def transform(
output_shape = tf.shape(images)[1:3]
output_shape = tf.convert_to_tensor(
- output_shape, tf.dtypes.int32, name="output_shape")
+ output_shape, tf.dtypes.int32, name='output_shape'
+ )
if not output_shape.get_shape().is_compatible_with([2]):
- raise ValueError("output_shape must be a 1-D Tensor of 2 elements: "
- "new_height, new_width")
+ raise ValueError(
+ 'output_shape must be a 1-D Tensor of 2 elements: new_height, new_width'
+ )
if len(transform_or_transforms.get_shape()) == 1:
transforms = transform_or_transforms[None]
elif transform_or_transforms.get_shape().ndims is None:
- raise ValueError("transforms rank must be statically known")
+ raise ValueError('transforms rank must be statically known')
elif len(transform_or_transforms.get_shape()) == 2:
transforms = transform_or_transforms
else:
transforms = transform_or_transforms
- raise ValueError("transforms should have rank 1 or 2, but got rank %d" %
- len(transforms.get_shape()))
+ raise ValueError(
+ 'transforms should have rank 1 or 2, but got rank %d'
+ % len(transforms.get_shape())
+ )
fill_value = tf.convert_to_tensor(
- fill_value, dtype=tf.float32, name="fill_value")
+ fill_value, dtype=tf.float32, name='fill_value'
+ )
output = tf.raw_ops.ImageProjectiveTransformV3(
- images=images,
- transforms=transforms,
- output_shape=output_shape,
- interpolation=interpolation.upper(),
- fill_mode=fill_mode.upper(),
- fill_value=fill_value,
+ images=images,
+ transforms=transforms,
+ output_shape=output_shape,
+ interpolation=interpolation.upper(),
+ fill_mode=fill_mode.upper(),
+ fill_value=fill_value,
)
return from_4d_image(output, original_ndims)
def angles_to_projective_transforms(
- angles: TensorLike,
- image_height: TensorLike,
- image_width: TensorLike,
- name: Optional[str] = None,
+ angles: TensorLike,
+ image_height: TensorLike,
+ image_width: TensorLike,
+ name: Optional[str] = None,
) -> tf.Tensor:
"""Returns projective transform(s) for the given angle(s).
- Args:
- angles: A scalar angle to rotate all images by, or (for batches of
- images) a vector with an angle to rotate each image in the batch. The
- rank must be statically known (the shape is not `TensorShape(None)`.
- image_height: Height of the image(s) to be transformed.
- image_width: Width of the image(s) to be transformed.
-
- Returns:
- A tensor of shape (num_images, 8). Projective transforms which can be
- given to `transform` op.
- """
- with tf.name_scope(name or "angles_to_projective_transforms"):
+ Args:
+ angles: A scalar angle to rotate all images by, or (for batches of
+ images) a vector with an angle to rotate each image in the batch. The
+      rank must be statically known (the shape is not `TensorShape(None)`).
+ image_height: Height of the image(s) to be transformed.
+ image_width: Width of the image(s) to be transformed.
+
+ Returns:
+ A tensor of shape (num_images, 8). Projective transforms which can be
+ given to `transform` op.
+ """
+ with tf.name_scope(name or 'angles_to_projective_transforms'):
angle_or_angles = tf.convert_to_tensor(
- angles, name="angles", dtype=tf.dtypes.float32)
+ angles, name='angles', dtype=tf.dtypes.float32
+ )
if len(angle_or_angles.get_shape()) not in (0, 1):
- raise ValueError("angles should have rank 0 or 1.")
+ raise ValueError('angles should have rank 0 or 1.')
if len(angle_or_angles.get_shape()) == 0:
angles = angle_or_angles[None]
@@ -252,112 +266,116 @@ def angles_to_projective_transforms(
cos_angles = tf.math.cos(angles)
sin_angles = tf.math.sin(angles)
- x_offset = ((image_width - 1) -
- (cos_angles * (image_width - 1) - sin_angles *
- (image_height - 1))) / 2.0
- y_offset = ((image_height - 1) -
- (sin_angles * (image_width - 1) + cos_angles *
- (image_height - 1))) / 2.0
+ x_offset = (
+ (image_width - 1)
+ - (cos_angles * (image_width - 1) - sin_angles * (image_height - 1))
+ ) / 2.0
+ y_offset = (
+ (image_height - 1)
+ - (sin_angles * (image_width - 1) + cos_angles * (image_height - 1))
+ ) / 2.0
num_angles = tf.shape(angles)[0]
return tf.concat(
- values=[
- cos_angles[:, None],
- -sin_angles[:, None],
- x_offset[:, None],
- sin_angles[:, None],
- cos_angles[:, None],
- y_offset[:, None],
- tf.zeros((num_angles, 2), tf.dtypes.float32),
- ],
- axis=1,
+ values=[
+ cos_angles[:, None],
+ -sin_angles[:, None],
+ x_offset[:, None],
+ sin_angles[:, None],
+ cos_angles[:, None],
+ y_offset[:, None],
+ tf.zeros((num_angles, 2), tf.dtypes.float32),
+ ],
+ axis=1,
)
def rotate_img(
- images: TensorLike,
- angles: TensorLike,
- interpolation: str = "nearest",
- fill_mode: str = "constant",
- name: Optional[str] = None,
- fill_value: TensorLike = 0.0,
+ images: TensorLike,
+ angles: TensorLike,
+ interpolation: str = 'nearest',
+ fill_mode: str = 'constant',
+ name: Optional[str] = None,
+ fill_value: TensorLike = 0.0,
) -> tf.Tensor:
"""Rotate image(s) counterclockwise by the passed angle(s) in radians.
- Args:
- images: A tensor of shape
- `(num_images, num_rows, num_columns, num_channels)`
- (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or
- `(num_rows, num_columns)` (HW).
- angles: A scalar angle to rotate all images by (if `images` has rank 4)
- a vector of length num_images, with an angle for each image in the
- batch.
- interpolation: Interpolation mode. Supported values: "nearest",
- "bilinear".
- fill_mode: Points outside the boundaries of the input are filled according
- to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
- - *reflect*: `(d c b a | a b c d | d c b a)`
- The input is extended by reflecting about the edge of the last pixel.
- - *constant*: `(k k k k | a b c d | k k k k)`
- The input is extended by filling all values beyond the edge with the
- same constant value k = 0.
- - *wrap*: `(a b c d | a b c d | a b c d)`
- The input is extended by wrapping around to the opposite edge.
- - *nearest*: `(a a a a | a b c d | d d d d)`
- The input is extended by the nearest pixel.
- fill_value: a float represents the value to be filled outside the
- boundaries when `fill_mode` is "constant".
- name: The name of the op.
-
- Returns:
- Image(s) with the same type and shape as `images`, rotated by the given
- angle(s). Empty space due to the rotation will be filled with zeros.
-
- Raises:
- TypeError: If `images` is an invalid type.
- """
- with tf.name_scope(name or "rotate"):
+ Args:
+ images: A tensor of shape
+ `(num_images, num_rows, num_columns, num_channels)`
+ (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or
+ `(num_rows, num_columns)` (HW).
+    angles: A scalar angle to rotate all images by, or (if `images` has rank 4)
+ a vector of length num_images, with an angle for each image in the
+ batch.
+ interpolation: Interpolation mode. Supported values: "nearest",
+ "bilinear".
+ fill_mode: Points outside the boundaries of the input are filled according
+ to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
+ - *reflect*: `(d c b a | a b c d | d c b a)`
+ The input is extended by reflecting about the edge of the last pixel.
+ - *constant*: `(k k k k | a b c d | k k k k)`
+ The input is extended by filling all values beyond the edge with the
+ same constant value k = 0.
+ - *wrap*: `(a b c d | a b c d | a b c d)`
+ The input is extended by wrapping around to the opposite edge.
+ - *nearest*: `(a a a a | a b c d | d d d d)`
+ The input is extended by the nearest pixel.
+    fill_value: a float representing the value to be filled outside the
+ boundaries when `fill_mode` is "constant".
+ name: The name of the op.
+
+ Returns:
+ Image(s) with the same type and shape as `images`, rotated by the given
+ angle(s). Empty space due to the rotation will be filled with zeros.
+
+ Raises:
+ TypeError: If `images` is an invalid type.
+ """
+ with tf.name_scope(name or 'rotate'):
image_or_images = tf.convert_to_tensor(images)
if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
- raise TypeError("Invalid dtype %s." % image_or_images.dtype)
+ raise TypeError('Invalid dtype %s.' % image_or_images.dtype)
images = to_4d_image(image_or_images)
original_ndims = get_ndims(image_or_images)
image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None]
image_width = tf.cast(tf.shape(images)[2], tf.dtypes.float32)[None]
output = transform(
- images,
- angles_to_projective_transforms(angles, image_height, image_width),
- interpolation=interpolation,
- fill_mode=fill_mode,
- fill_value=fill_value,
+ images,
+ angles_to_projective_transforms(angles, image_height, image_width),
+ interpolation=interpolation,
+ fill_mode=fill_mode,
+ fill_value=fill_value,
)
return from_4d_image(output, original_ndims)
-def translations_to_projective_transforms(translations: TensorLike,
- name: Optional[str] = None
- ) -> tf.Tensor:
+def translations_to_projective_transforms(
+ translations: TensorLike, name: Optional[str] = None
+) -> tf.Tensor:
"""Returns projective transform(s) for the given translation(s).
- Args:
- translations: A 2-element list representing `[dx, dy]` or a matrix of
- 2-element lists representing `[dx, dy]` to translate for each image
- (for a batch of images). The rank must be statically known
- (the shape is not `TensorShape(None)`).
- name: The name of the op.
- Returns:
- A tensor of shape `(num_images, 8)` projective transforms which can be
- given to `tfa.image.transform`.
- """
- with tf.name_scope(name or "translations_to_projective_transforms"):
+ Args:
+ translations: A 2-element list representing `[dx, dy]` or a matrix of
+ 2-element lists representing `[dx, dy]` to translate for each image
+ (for a batch of images). The rank must be statically known
+ (the shape is not `TensorShape(None)`).
+ name: The name of the op.
+ Returns:
+ A tensor of shape `(num_images, 8)` projective transforms which can be
+ given to `tfa.image.transform`.
+ """
+ with tf.name_scope(name or 'translations_to_projective_transforms'):
translation_or_translations = tf.convert_to_tensor(
- translations, name="translations", dtype=tf.dtypes.float32)
+ translations, name='translations', dtype=tf.dtypes.float32
+ )
if translation_or_translations.get_shape().ndims is None:
raise TypeError(
- "translation_or_translations rank must be statically known")
+ 'translation_or_translations rank must be statically known'
+ )
if len(translation_or_translations.get_shape()) not in (1, 2):
- raise TypeError("Translations should have rank 1 or 2.")
+ raise TypeError('Translations should have rank 1 or 2.')
if len(translation_or_translations.get_shape()) == 1:
translations = translation_or_translations[None]
@@ -372,67 +390,67 @@ def translations_to_projective_transforms(translations: TensorLike,
# where the last entry is implicit.
# Translation matrices are always float32.
return tf.concat(
- values=[
- tf.ones((num_translations, 1), tf.dtypes.float32),
- tf.zeros((num_translations, 1), tf.dtypes.float32),
- -translations[:, 0, None],
- tf.zeros((num_translations, 1), tf.dtypes.float32),
- tf.ones((num_translations, 1), tf.dtypes.float32),
- -translations[:, 1, None],
- tf.zeros((num_translations, 2), tf.dtypes.float32),
- ],
- axis=1,
+ values=[
+ tf.ones((num_translations, 1), tf.dtypes.float32),
+ tf.zeros((num_translations, 1), tf.dtypes.float32),
+ -translations[:, 0, None],
+ tf.zeros((num_translations, 1), tf.dtypes.float32),
+ tf.ones((num_translations, 1), tf.dtypes.float32),
+ -translations[:, 1, None],
+ tf.zeros((num_translations, 2), tf.dtypes.float32),
+ ],
+ axis=1,
)
@tf.function
def translate(
- images: TensorLike,
- translations: TensorLike,
- interpolation: str = "nearest",
- fill_mode: str = "constant",
- name: Optional[str] = None,
- fill_value: TensorLike = 0.0,
+ images: TensorLike,
+ translations: TensorLike,
+ interpolation: str = 'nearest',
+ fill_mode: str = 'constant',
+ name: Optional[str] = None,
+ fill_value: TensorLike = 0.0,
) -> tf.Tensor:
"""Translate image(s) by the passed vectors(s).
- Args:
- images: A tensor of shape
- `(num_images, num_rows, num_columns, num_channels)` (NHWC),
- `(num_rows, num_columns, num_channels)` (HWC), or
- `(num_rows, num_columns)` (HW). The rank must be statically known (the
- shape is not `TensorShape(None)`).
- translations: A vector representing `[dx, dy]` or (if `images` has rank 4)
- a matrix of length num_images, with a `[dx, dy]` vector for each image
- in the batch.
- interpolation: Interpolation mode. Supported values: "nearest",
- "bilinear".
- fill_mode: Points outside the boundaries of the input are filled according
- to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
- - *reflect*: `(d c b a | a b c d | d c b a)`
- The input is extended by reflecting about the edge of the last pixel.
- - *constant*: `(k k k k | a b c d | k k k k)`
- The input is extended by filling all values beyond the edge with the
- same constant value k = 0.
- - *wrap*: `(a b c d | a b c d | a b c d)`
- The input is extended by wrapping around to the opposite edge.
- - *nearest*: `(a a a a | a b c d | d d d d)`
- The input is extended by the nearest pixel.
- fill_value: a float represents the value to be filled outside the
- boundaries when `fill_mode` is "constant".
- name: The name of the op.
- Returns:
- Image(s) with the same type and shape as `images`, translated by the
- given vector(s). Empty space due to the translation will be filled with
- zeros.
- Raises:
- TypeError: If `images` is an invalid type.
- """
- with tf.name_scope(name or "translate"):
+ Args:
+ images: A tensor of shape
+ `(num_images, num_rows, num_columns, num_channels)` (NHWC),
+ `(num_rows, num_columns, num_channels)` (HWC), or
+ `(num_rows, num_columns)` (HW). The rank must be statically known (the
+ shape is not `TensorShape(None)`).
+ translations: A vector representing `[dx, dy]` or (if `images` has rank 4)
+ a matrix of length num_images, with a `[dx, dy]` vector for each image
+ in the batch.
+ interpolation: Interpolation mode. Supported values: "nearest",
+ "bilinear".
+ fill_mode: Points outside the boundaries of the input are filled according
+ to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
+ - *reflect*: `(d c b a | a b c d | d c b a)`
+ The input is extended by reflecting about the edge of the last pixel.
+ - *constant*: `(k k k k | a b c d | k k k k)`
+ The input is extended by filling all values beyond the edge with the
+ same constant value k = 0.
+ - *wrap*: `(a b c d | a b c d | a b c d)`
+ The input is extended by wrapping around to the opposite edge.
+ - *nearest*: `(a a a a | a b c d | d d d d)`
+ The input is extended by the nearest pixel.
+    fill_value: a float representing the value to be filled outside the
+ boundaries when `fill_mode` is "constant".
+ name: The name of the op.
+ Returns:
+ Image(s) with the same type and shape as `images`, translated by the
+ given vector(s). Empty space due to the translation will be filled with
+ zeros.
+ Raises:
+ TypeError: If `images` is an invalid type.
+ """
+ with tf.name_scope(name or 'translate'):
return transform(
- images,
- translations_to_projective_transforms(translations),
- interpolation=interpolation,
- fill_mode=fill_mode,
- fill_value=fill_value,
+ images,
+ translations_to_projective_transforms(translations),
+ interpolation=interpolation,
+ fill_mode=fill_mode,
+ fill_value=fill_value,
)
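All helpers in this file share the 8-parameter projective-transform convention spelled out in `transform`'s docstring: a row `[a0, a1, a2, b0, b1, b2, c0, c1]` maps an *output* pixel `(x, y)` to the *input* pixel `((a0*x + a1*y + a2) / k, (b0*x + b1*y + b2) / k)` with `k = c0*x + c1*y + 1`. The small sketch below only illustrates that math; it is not repository code, and it shows why `translations_to_projective_transforms` negates `dx`/`dy`:

```python
def apply_inverse_transform(params, x, y):
  a0, a1, a2, b0, b1, b2, c0, c1 = params
  k = c0 * x + c1 * y + 1.0
  return (a0 * x + a1 * y + a2) / k, (b0 * x + b1 * y + b2) / k


# Translating an image by (dx, dy) means output pixel (x, y) is read from
# input pixel (x - dx, y - dy), hence the negated offsets built by
# translations_to_projective_transforms.
dx, dy = 5.0, 3.0
translate_params = [1.0, 0.0, -dx, 0.0, 1.0, -dy, 0.0, 0.0]
print(apply_inverse_transform(translate_params, 10.0, 10.0))  # (5.0, 7.0)
```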
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py
index 66105335b..f782e50a1 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py
@@ -7,29 +7,30 @@
import functools
from typing import Dict, Iterator, Tuple
-from flax import jax_utils
import jax
import tensorflow as tf
import tensorflow_datasets as tfds
+from flax import jax_utils
-from algoperf import data_utils
-from algoperf import spec
+from algoperf import data_utils, spec
from algoperf.workloads.imagenet_resnet.imagenet_jax import randaugment
TFDS_SPLIT_NAME = {
- 'train': 'train', 'eval_train': 'train', 'validation': 'validation'
+ 'train': 'train',
+ 'eval_train': 'train',
+ 'validation': 'validation',
}
-def _distorted_bounding_box_crop(image_bytes: spec.Tensor,
- rng: spec.RandomState,
- bbox: spec.Tensor,
- min_object_covered: float = 0.1,
- aspect_ratio_range: Tuple[float,
- float] = (0.75,
- 1.33),
- area_range: Tuple[float, float] = (0.05, 1.0),
- max_attempts: int = 100) -> spec.Tensor:
+def _distorted_bounding_box_crop(
+ image_bytes: spec.Tensor,
+ rng: spec.RandomState,
+ bbox: spec.Tensor,
+ min_object_covered: float = 0.1,
+ aspect_ratio_range: Tuple[float, float] = (0.75, 1.33),
+ area_range: Tuple[float, float] = (0.05, 1.0),
+ max_attempts: int = 100,
+) -> spec.Tensor:
"""Generates cropped_image using one of the bboxes randomly distorted.
See `tf.image.sample_distorted_bounding_box` for more documentation.
@@ -57,14 +58,15 @@ def _distorted_bounding_box_crop(image_bytes: spec.Tensor,
"""
shape = tf.io.extract_jpeg_shape(image_bytes)
bbox_begin, bbox_size, _ = tf.image.stateless_sample_distorted_bounding_box(
- shape,
- seed=rng,
- bounding_boxes=bbox,
- min_object_covered=min_object_covered,
- aspect_ratio_range=aspect_ratio_range,
- area_range=area_range,
- max_attempts=max_attempts,
- use_image_if_no_bounding_boxes=True)
+ shape,
+ seed=rng,
+ bounding_boxes=bbox,
+ min_object_covered=min_object_covered,
+ aspect_ratio_range=aspect_ratio_range,
+ area_range=area_range,
+ max_attempts=max_attempts,
+ use_image_if_no_bounding_boxes=True,
+ )
# Crop the image to the specified bounding box.
offset_y, offset_x, _ = tf.unstack(bbox_begin)
@@ -84,8 +86,9 @@ def resize(image: spec.Tensor, image_size: int) -> spec.Tensor:
Returns:
Resized image 'Tensor'.
"""
- return tf.image.resize([image], [image_size, image_size],
- method=tf.image.ResizeMethod.BICUBIC)[0]
+ return tf.image.resize(
+ [image], [image_size, image_size], method=tf.image.ResizeMethod.BICUBIC
+ )[0]
def _at_least_x_are_equal(a: spec.Tensor, b: spec.Tensor, x: float) -> bool:
@@ -95,80 +98,93 @@ def _at_least_x_are_equal(a: spec.Tensor, b: spec.Tensor, x: float) -> bool:
return tf.greater_equal(tf.reduce_sum(match), x)
-def _decode_and_random_crop(image_bytes: spec.Tensor,
- rng: spec.RandomState,
- image_size: int,
- aspect_ratio_range: Tuple[float, float],
- area_range: Tuple[float, float],
- resize_size: int) -> spec.Tensor:
+def _decode_and_random_crop(
+ image_bytes: spec.Tensor,
+ rng: spec.RandomState,
+ image_size: int,
+ aspect_ratio_range: Tuple[float, float],
+ area_range: Tuple[float, float],
+ resize_size: int,
+) -> spec.Tensor:
"""Make a random crop of image_size."""
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
image = _distorted_bounding_box_crop(
- image_bytes,
- rng,
- bbox,
- min_object_covered=0.1,
- aspect_ratio_range=aspect_ratio_range,
- area_range=area_range,
- max_attempts=10)
+ image_bytes,
+ rng,
+ bbox,
+ min_object_covered=0.1,
+ aspect_ratio_range=aspect_ratio_range,
+ area_range=area_range,
+ max_attempts=10,
+ )
original_shape = tf.io.extract_jpeg_shape(image_bytes)
bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)
image = tf.cond(
- bad,
- lambda: _decode_and_center_crop(image_bytes, image_size, resize_size),
- lambda: resize(image, image_size))
+ bad,
+ lambda: _decode_and_center_crop(image_bytes, image_size, resize_size),
+ lambda: resize(image, image_size),
+ )
return image
-def _decode_and_center_crop(image_bytes: spec.Tensor,
- image_size: int,
- resize_size: int) -> spec.Tensor:
+def _decode_and_center_crop(
+ image_bytes: spec.Tensor, image_size: int, resize_size: int
+) -> spec.Tensor:
"""Crops to center of image with padding then scales image_size."""
shape = tf.io.extract_jpeg_shape(image_bytes)
image_height = shape[0]
image_width = shape[1]
padded_center_crop_size = tf.cast(
- ((image_size / resize_size) *
- tf.cast(tf.minimum(image_height, image_width), tf.float32)),
- tf.int32)
+ (
+ (image_size / resize_size)
+ * tf.cast(tf.minimum(image_height, image_width), tf.float32)
+ ),
+ tf.int32,
+ )
offset_height = ((image_height - padded_center_crop_size) + 1) // 2
offset_width = ((image_width - padded_center_crop_size) + 1) // 2
- crop_window = tf.stack([
+ crop_window = tf.stack(
+ [
offset_height,
offset_width,
padded_center_crop_size,
padded_center_crop_size,
- ])
+ ]
+ )
image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
image = resize(image, image_size)
return image
-def normalize_image(image: spec.Tensor,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float]) -> spec.Tensor:
+def normalize_image(
+ image: spec.Tensor,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+) -> spec.Tensor:
image -= tf.constant(mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
image /= tf.constant(stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
return image
-def preprocess_for_train(image_bytes: spec.Tensor,
- rng: spec.RandomState,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float],
- aspect_ratio_range: Tuple[float, float],
- area_range: Tuple[float, float],
- image_size: int,
- resize_size: int,
- dtype: tf.DType = tf.float32,
- use_randaug: bool = False,
- randaug_num_layers: int = 2,
- randaug_magnitude: int = 10) -> spec.Tensor:
+def preprocess_for_train(
+ image_bytes: spec.Tensor,
+ rng: spec.RandomState,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+ aspect_ratio_range: Tuple[float, float],
+ area_range: Tuple[float, float],
+ image_size: int,
+ resize_size: int,
+ dtype: tf.DType = tf.float32,
+ use_randaug: bool = False,
+ randaug_num_layers: int = 2,
+ randaug_magnitude: int = 10,
+) -> spec.Tensor:
"""Preprocesses the given image for training.
Args:
@@ -182,33 +198,36 @@ def preprocess_for_train(image_bytes: spec.Tensor,
"""
rngs = tf.random.experimental.stateless_split(rng, 3)
- image = _decode_and_random_crop(image_bytes,
- rngs[0],
- image_size,
- aspect_ratio_range,
- area_range,
- resize_size)
+ image = _decode_and_random_crop(
+ image_bytes,
+ rngs[0],
+ image_size,
+ aspect_ratio_range,
+ area_range,
+ resize_size,
+ )
image = tf.reshape(image, [image_size, image_size, 3])
image = tf.image.stateless_random_flip_left_right(image, seed=rngs[1])
if use_randaug:
image = tf.cast(tf.clip_by_value(image, 0, 255), tf.uint8)
- image = randaugment.distort_image_with_randaugment(image,
- randaug_num_layers,
- randaug_magnitude,
- rngs[2])
+ image = randaugment.distort_image_with_randaugment(
+ image, randaug_num_layers, randaug_magnitude, rngs[2]
+ )
image = tf.cast(image, tf.float32)
image = normalize_image(image, mean_rgb, stddev_rgb)
image = tf.image.convert_image_dtype(image, dtype=dtype)
return image
-def preprocess_for_eval(image_bytes: spec.Tensor,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float],
- image_size: int,
- resize_size: int,
- dtype: tf.DType = tf.float32) -> spec.Tensor:
+def preprocess_for_eval(
+ image_bytes: spec.Tensor,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+ image_size: int,
+ resize_size: int,
+ dtype: tf.DType = tf.float32,
+) -> spec.Tensor:
"""Preprocesses the given image for evaluation.
Args:
@@ -229,10 +248,12 @@ def preprocess_for_eval(image_bytes: spec.Tensor,
# Modified from
# github.com/google/init2winit/blob/master/init2winit/dataset_lib/
# image_preprocessing.py.
-def mixup_tf(key: spec.RandomState,
- inputs: spec.Tensor,
- targets: spec.Tensor,
- alpha: float = 0.2) -> Tuple[spec.Tensor, spec.Tensor]:
+def mixup_tf(
+ key: spec.RandomState,
+ inputs: spec.Tensor,
+ targets: spec.Tensor,
+ alpha: float = 0.2,
+) -> Tuple[spec.Tensor, spec.Tensor]:
"""Perform mixup https://arxiv.org/abs/1710.09412.
NOTE: Code taken from https://github.com/google/big_vision with variables
@@ -261,24 +282,26 @@ def mixup_tf(key: spec.RandomState,
return inputs, targets
-def create_split(split,
- dataset_builder,
- rng,
- global_batch_size,
- train,
- image_size,
- resize_size,
- mean_rgb,
- stddev_rgb,
- cache=False,
- repeat_final_dataset=False,
- aspect_ratio_range=(0.75, 4.0 / 3.0),
- area_range=(0.08, 1.0),
- use_mixup=False,
- mixup_alpha=0.1,
- use_randaug=False,
- randaug_num_layers=2,
- randaug_magnitude=10) -> Iterator[Dict[str, spec.Tensor]]:
+def create_split(
+ split,
+ dataset_builder,
+ rng,
+ global_batch_size,
+ train,
+ image_size,
+ resize_size,
+ mean_rgb,
+ stddev_rgb,
+ cache=False,
+ repeat_final_dataset=False,
+ aspect_ratio_range=(0.75, 4.0 / 3.0),
+ area_range=(0.08, 1.0),
+ use_mixup=False,
+ mixup_alpha=0.1,
+ use_randaug=False,
+ randaug_num_layers=2,
+ randaug_magnitude=10,
+) -> Iterator[Dict[str, spec.Tensor]]:
"""Creates a split from the ImageNet dataset using TensorFlow Datasets."""
shuffle_rng, preprocess_rng, mixup_rng = jax.random.split(rng, 3)
@@ -286,34 +309,35 @@ def decode_example(example_index, example):
dtype = tf.float32
if train:
per_step_preprocess_rng = tf.random.experimental.stateless_fold_in(
- tf.cast(preprocess_rng, tf.int64), example_index)
-
- image = preprocess_for_train(example['image'],
- per_step_preprocess_rng,
- mean_rgb,
- stddev_rgb,
- aspect_ratio_range,
- area_range,
- image_size,
- resize_size,
- dtype,
- use_randaug,
- randaug_num_layers,
- randaug_magnitude)
+ tf.cast(preprocess_rng, tf.int64), example_index
+ )
+
+ image = preprocess_for_train(
+ example['image'],
+ per_step_preprocess_rng,
+ mean_rgb,
+ stddev_rgb,
+ aspect_ratio_range,
+ area_range,
+ image_size,
+ resize_size,
+ dtype,
+ use_randaug,
+ randaug_num_layers,
+ randaug_magnitude,
+ )
else:
- image = preprocess_for_eval(example['image'],
- mean_rgb,
- stddev_rgb,
- image_size,
- resize_size,
- dtype)
+ image = preprocess_for_eval(
+ example['image'], mean_rgb, stddev_rgb, image_size, resize_size, dtype
+ )
return {'inputs': image, 'targets': example['label']}
ds = dataset_builder.as_dataset(
- split=TFDS_SPLIT_NAME[split],
- decoders={
- 'image': tfds.decode.SkipDecoding(),
- })
+ split=TFDS_SPLIT_NAME[split],
+ decoders={
+ 'image': tfds.decode.SkipDecoding(),
+ },
+ )
options = tf.data.Options()
options.threading.private_threadpool_size = 48
ds = ds.with_options(options)
@@ -336,18 +360,21 @@ def decode_example(example_index, example):
def mixup_batch(batch_index, batch):
per_batch_mixup_rng = tf.random.experimental.stateless_fold_in(
- mixup_rng, batch_index)
+ mixup_rng, batch_index
+ )
(inputs, targets) = mixup_tf(
- per_batch_mixup_rng,
- batch['inputs'],
- batch['targets'],
- alpha=mixup_alpha)
+ per_batch_mixup_rng,
+ batch['inputs'],
+ batch['targets'],
+ alpha=mixup_alpha,
+ )
batch['inputs'] = inputs
batch['targets'] = targets
return batch
ds = ds.enumerate().map(
- mixup_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ mixup_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE
+ )
else:
raise ValueError('Mixup can only be used for the training split.')
@@ -359,44 +386,48 @@ def mixup_batch(batch_index, batch):
return ds
-def create_input_iter(split: str,
- dataset_builder: tfds.core.dataset_builder.DatasetBuilder,
- rng: spec.RandomState,
- global_batch_size: int,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float],
- image_size: int,
- resize_size: int,
- aspect_ratio_range: Tuple[float, float],
- area_range: Tuple[float, float],
- train: bool,
- cache: bool,
- repeat_final_dataset: bool,
- use_mixup: bool,
- mixup_alpha: float,
- use_randaug: bool) -> Iterator[Dict[str, spec.Tensor]]:
+def create_input_iter(
+ split: str,
+ dataset_builder: tfds.core.dataset_builder.DatasetBuilder,
+ rng: spec.RandomState,
+ global_batch_size: int,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+ image_size: int,
+ resize_size: int,
+ aspect_ratio_range: Tuple[float, float],
+ area_range: Tuple[float, float],
+ train: bool,
+ cache: bool,
+ repeat_final_dataset: bool,
+ use_mixup: bool,
+ mixup_alpha: float,
+ use_randaug: bool,
+) -> Iterator[Dict[str, spec.Tensor]]:
ds = create_split(
- split,
- dataset_builder,
- rng,
- global_batch_size,
- train=train,
- image_size=image_size,
- resize_size=resize_size,
- mean_rgb=mean_rgb,
- stddev_rgb=stddev_rgb,
- cache=cache,
- repeat_final_dataset=repeat_final_dataset,
- aspect_ratio_range=aspect_ratio_range,
- area_range=area_range,
- use_mixup=use_mixup,
- mixup_alpha=mixup_alpha,
- use_randaug=use_randaug)
+ split,
+ dataset_builder,
+ rng,
+ global_batch_size,
+ train=train,
+ image_size=image_size,
+ resize_size=resize_size,
+ mean_rgb=mean_rgb,
+ stddev_rgb=stddev_rgb,
+ cache=cache,
+ repeat_final_dataset=repeat_final_dataset,
+ aspect_ratio_range=aspect_ratio_range,
+ area_range=area_range,
+ use_mixup=use_mixup,
+ mixup_alpha=mixup_alpha,
+ use_randaug=use_randaug,
+ )
it = map(
- functools.partial(
- data_utils.shard_and_maybe_pad_np,
- global_batch_size=global_batch_size),
- ds)
+ functools.partial(
+ data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size
+ ),
+ ds,
+ )
# Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%.
it = jax_utils.prefetch_to_device(it, 2)
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py
index ffa60b260..84ad4fe21 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py
@@ -7,8 +7,8 @@
import functools
from typing import Any, Callable, Optional, Tuple
-from flax import linen as nn
import jax.numpy as jnp
+from flax import linen as nn
from algoperf import spec
@@ -17,12 +17,13 @@
class ResNetBlock(nn.Module):
"""ResNet block."""
+
filters: int
conv: ModuleDef
norm: ModuleDef
act: Callable
strides: Tuple[int, int] = (1, 1)
- bn_init_scale: float = 0.
+ bn_init_scale: float = 0.0
@nn.compact
def __call__(self, x: spec.Tensor) -> spec.Tensor:
@@ -35,8 +36,8 @@ def __call__(self, x: spec.Tensor) -> spec.Tensor:
if residual.shape != y.shape or self.strides != (1, 1):
residual = self.conv(
- self.filters, (1, 1), self.strides, name='Conv_proj')(
- residual)
+ self.filters, (1, 1), self.strides, name='Conv_proj'
+ )(residual)
residual = self.norm(name='BatchNorm_proj')(residual)
return self.act(residual + y)
@@ -44,6 +45,7 @@ def __call__(self, x: spec.Tensor) -> spec.Tensor:
class BottleneckResNetBlock(nn.Module):
"""Bottleneck ResNet block."""
+
filters: int
conv: ModuleDef
norm: ModuleDef
@@ -65,8 +67,8 @@ def __call__(self, x: spec.Tensor) -> spec.Tensor:
if residual.shape != y.shape or self.strides != (1, 1):
residual = self.conv(
- self.filters * 4, (1, 1), self.strides, name='Conv_proj')(
- residual)
+ self.filters * 4, (1, 1), self.strides, name='Conv_proj'
+ )(residual)
residual = self.norm(name='BatchNorm_proj')(residual)
return self.act(residual + y)
@@ -79,30 +81,35 @@ class ResNet(nn.Module):
num_filters: int = 64
dtype: Any = jnp.float32
act: Callable = nn.relu
- bn_init_scale: float = 0.
+ bn_init_scale: float = 0.0
@nn.compact
- def __call__(self,
- x: spec.Tensor,
- update_batch_norm: bool = True,
- use_running_average_bn: Optional[bool] = None) -> spec.Tensor:
+ def __call__(
+ self,
+ x: spec.Tensor,
+ update_batch_norm: bool = True,
+ use_running_average_bn: Optional[bool] = None,
+ ) -> spec.Tensor:
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
# Preserve default behavior for backwards compatibility
if use_running_average_bn is None:
use_running_average_bn = not update_batch_norm
norm = functools.partial(
- nn.BatchNorm,
- use_running_average=use_running_average_bn,
- momentum=0.9,
- epsilon=1e-5,
- dtype=self.dtype)
+ nn.BatchNorm,
+ use_running_average=use_running_average_bn,
+ momentum=0.9,
+ epsilon=1e-5,
+ dtype=self.dtype,
+ )
x = conv(
- self.num_filters, (7, 7), (2, 2),
- padding=[(3, 3), (3, 3)],
- name='Conv_init')(
- x)
+ self.num_filters,
+ (7, 7),
+ (2, 2),
+ padding=[(3, 3), (3, 3)],
+ name='Conv_init',
+ )(x)
x = norm(name='BatchNorm_init')(x)
x = self.act(x)
x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))
@@ -110,23 +117,23 @@ def __call__(self,
for j in range(block_size):
strides = (2, 2) if i > 0 and j == 0 else (1, 1)
x = self.block_cls(
- self.num_filters * 2**i,
- strides=strides,
- conv=conv,
- norm=norm,
- act=self.act,
- bn_init_scale=self.bn_init_scale)(
- x)
+ self.num_filters * 2**i,
+ strides=strides,
+ conv=conv,
+ norm=norm,
+ act=self.act,
+ bn_init_scale=self.bn_init_scale,
+ )(x)
x = jnp.mean(x, axis=(1, 2))
x = nn.Dense(
- self.num_classes,
- kernel_init=nn.initializers.normal(),
- dtype=self.dtype)(
- x)
+ self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype
+ )(x)
return x
ResNet18 = functools.partial(
- ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock)
+ ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock
+)
ResNet50 = functools.partial(
- ResNet, stage_sizes=(3, 4, 6, 3), block_cls=BottleneckResNetBlock)
+ ResNet, stage_sizes=(3, 4, 6, 3), block_cls=BottleneckResNetBlock
+)
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py
index c68e2de33..87e218a0c 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py
@@ -9,16 +9,15 @@
import tensorflow as tf
-from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \
- rotate_img
-from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \
- transform
-from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \
- translate
+from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import (
+ rotate_img,
+ transform,
+ translate,
+)
# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
-_MAX_LEVEL = 10.
+_MAX_LEVEL = 10.0
def blend(image1, image2, factor):
@@ -86,10 +85,12 @@ def cutout(image, pad_size, replace=0):
# Sample the center location in the image where the zero mask will be applied.
cutout_center_height = tf.random.uniform(
- shape=[], minval=0, maxval=image_height, dtype=tf.int32)
+ shape=[], minval=0, maxval=image_height, dtype=tf.int32
+ )
cutout_center_width = tf.random.uniform(
- shape=[], minval=0, maxval=image_width, dtype=tf.int32)
+ shape=[], minval=0, maxval=image_width, dtype=tf.int32
+ )
lower_pad = tf.maximum(0, cutout_center_height - pad_size)
upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size)
@@ -97,20 +98,18 @@ def cutout(image, pad_size, replace=0):
right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size)
cutout_shape = [
- image_height - (lower_pad + upper_pad),
- image_width - (left_pad + right_pad),
+ image_height - (lower_pad + upper_pad),
+ image_width - (left_pad + right_pad),
]
padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
mask = tf.pad(
- tf.zeros(cutout_shape, dtype=image.dtype),
- padding_dims,
- constant_values=1)
+ tf.zeros(cutout_shape, dtype=image.dtype), padding_dims, constant_values=1
+ )
mask = tf.expand_dims(mask, -1)
mask = tf.tile(mask, [1, 1, 3])
image = tf.where(
- tf.equal(mask, 0),
- tf.ones_like(image, dtype=image.dtype) * replace,
- image)
+ tf.equal(mask, 0), tf.ones_like(image, dtype=image.dtype) * replace, image
+ )
return image
@@ -204,7 +203,7 @@ def shear_x(image, level, replace):
# with a matrix form of:
# [1 level
# 0 1].
- image = transform(wrap(image), [1., level, 0., 0., 1., 0., 0., 0.])
+ image = transform(wrap(image), [1.0, level, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
return unwrap(image, replace)
@@ -214,7 +213,7 @@ def shear_y(image, level, replace):
# with a matrix form of:
# [1 0
# level 1].
- image = transform(wrap(image), [1., 0., 0., level, 1., 0., 0., 0.])
+ image = transform(wrap(image), [1.0, 0.0, 0.0, level, 1.0, 0.0, 0.0, 0.0])
return unwrap(image, replace)
@@ -264,9 +263,12 @@ def sharpness(image, factor):
# Make image 4D for conv operation.
image = tf.expand_dims(image, 0)
# SMOOTH PIL Kernel.
- kernel = tf.constant([[1, 1, 1], [1, 5, 1], [1, 1, 1]],
- dtype=tf.float32,
- shape=[3, 3, 1, 1]) / 13.
+ kernel = (
+ tf.constant(
+ [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, shape=[3, 3, 1, 1]
+ )
+ / 13.0
+ )
# Tile across channel dimension.
kernel = tf.tile(kernel, [1, 1, 3, 1])
strides = [1, 1, 1, 1]
@@ -274,7 +276,8 @@ def sharpness(image, factor):
# Some augmentation that uses depth-wise conv will cause crashing when
# training on GPU.
degenerate = tf.nn.depthwise_conv2d(
- image, kernel, strides, padding='VALID', dilations=[1, 1])
+ image, kernel, strides, padding='VALID', dilations=[1, 1]
+ )
degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0])
@@ -316,9 +319,10 @@ def build_lut(histo, step):
# If step is zero, return the original image. Otherwise, build
# lut from the full histogram and step and then index from it.
result = tf.cond(
- tf.equal(step, 0),
- lambda: im,
- lambda: tf.gather(build_lut(histo, step), im))
+ tf.equal(step, 0),
+ lambda: im,
+ lambda: tf.gather(build_lut(histo, step), im),
+ )
return tf.cast(result, tf.uint8)
@@ -373,9 +377,10 @@ def unwrap(image, replace):
# Where they are zero, fill them in with 'replace'.
flattened_image = tf.where(
- tf.equal(alpha_channel, 0),
- tf.ones_like(flattened_image, dtype=image.dtype) * replace,
- flattened_image)
+ tf.equal(alpha_channel, 0),
+ tf.ones_like(flattened_image, dtype=image.dtype) * replace,
+ flattened_image,
+ )
image = tf.reshape(flattened_image, image_shape)
image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3])
@@ -383,22 +388,22 @@ def unwrap(image, replace):
NAME_TO_FUNC = {
- 'AutoContrast': autocontrast,
- 'Equalize': equalize,
- 'Invert': invert,
- 'Rotate': rotate,
- 'Posterize': posterize,
- 'Solarize': solarize,
- 'SolarizeAdd': solarize_add,
- 'Color': color,
- 'Contrast': contrast,
- 'Brightness': brightness,
- 'Sharpness': sharpness,
- 'ShearX': shear_x,
- 'ShearY': shear_y,
- 'TranslateX': translate_x,
- 'TranslateY': translate_y,
- 'Cutout': cutout,
+ 'AutoContrast': autocontrast,
+ 'Equalize': equalize,
+ 'Invert': invert,
+ 'Rotate': rotate,
+ 'Posterize': posterize,
+ 'Solarize': solarize,
+ 'SolarizeAdd': solarize_add,
+ 'Color': color,
+ 'Contrast': contrast,
+ 'Brightness': brightness,
+ 'Sharpness': sharpness,
+ 'ShearX': shear_x,
+ 'ShearY': shear_y,
+ 'TranslateX': translate_x,
+ 'TranslateY': translate_y,
+ 'Cutout': cutout,
}
@@ -410,7 +415,7 @@ def _randomly_negate_tensor(tensor):
def _rotate_level_to_arg(level):
- level = (level / _MAX_LEVEL) * 30.
+ level = (level / _MAX_LEVEL) * 30.0
level = _randomly_negate_tensor(level)
return (level,)
@@ -435,47 +440,28 @@ def _translate_level_to_arg(level, translate_const):
def level_to_arg(cutout_const, translate_const):
return {
- 'AutoContrast':
- lambda level: (),
- 'Equalize':
- lambda level: (),
- 'Invert':
- lambda level: (),
- 'Rotate':
- _rotate_level_to_arg,
- 'Posterize':
- lambda level: (int((level / _MAX_LEVEL) * 4),),
- 'Solarize':
- lambda level: (int((level / _MAX_LEVEL) * 256),),
- 'SolarizeAdd':
- lambda level: (int((level / _MAX_LEVEL) * 110),),
- 'Color':
- _enhance_level_to_arg,
- 'Contrast':
- _enhance_level_to_arg,
- 'Brightness':
- _enhance_level_to_arg,
- 'Sharpness':
- _enhance_level_to_arg,
- 'ShearX':
- _shear_level_to_arg,
- 'ShearY':
- _shear_level_to_arg,
- 'Cutout':
- lambda level: (int((level / _MAX_LEVEL) * cutout_const),),
- 'TranslateX':
- lambda level: _translate_level_to_arg(level, translate_const),
- 'TranslateY':
- lambda level: _translate_level_to_arg(level, translate_const),
+ 'AutoContrast': lambda level: (),
+ 'Equalize': lambda level: (),
+ 'Invert': lambda level: (),
+ 'Rotate': _rotate_level_to_arg,
+ 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4),),
+ 'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),),
+ 'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),),
+ 'Color': _enhance_level_to_arg,
+ 'Contrast': _enhance_level_to_arg,
+ 'Brightness': _enhance_level_to_arg,
+ 'Sharpness': _enhance_level_to_arg,
+ 'ShearX': _shear_level_to_arg,
+ 'ShearY': _shear_level_to_arg,
+ 'Cutout': lambda level: (int((level / _MAX_LEVEL) * cutout_const),),
+ 'TranslateX': lambda level: _translate_level_to_arg(level, translate_const),
+ 'TranslateY': lambda level: _translate_level_to_arg(level, translate_const),
}
-def _parse_policy_info(name,
- prob,
- level,
- replace_value,
- cutout_const,
- translate_const):
+def _parse_policy_info(
+ name, prob, level, replace_value, cutout_const, translate_const
+):
"""Return the function that corresponds to `name` and update `level` param."""
func = NAME_TO_FUNC[name]
args = level_to_arg(cutout_const, translate_const)[name](level)
@@ -514,45 +500,49 @@ def distort_image_with_randaugment(image, num_layers, magnitude, key):
"""
replace_value = [128] * 3
available_ops = [
- 'AutoContrast',
- 'Equalize',
- 'Invert',
- 'Rotate',
- 'Posterize',
- 'Solarize',
- 'Color',
- 'Contrast',
- 'Brightness',
- 'Sharpness',
- 'ShearX',
- 'ShearY',
- 'TranslateX',
- 'TranslateY',
- 'Cutout',
- 'SolarizeAdd',
+ 'AutoContrast',
+ 'Equalize',
+ 'Invert',
+ 'Rotate',
+ 'Posterize',
+ 'Solarize',
+ 'Color',
+ 'Contrast',
+ 'Brightness',
+ 'Sharpness',
+ 'ShearX',
+ 'ShearY',
+ 'TranslateX',
+ 'TranslateY',
+ 'Cutout',
+ 'SolarizeAdd',
]
for layer_num in range(num_layers):
key = tf.random.experimental.stateless_fold_in(key, layer_num)
- op_to_select = tf.random.stateless_uniform([],
- seed=key,
- maxval=len(available_ops),
- dtype=tf.int32)
+ op_to_select = tf.random.stateless_uniform(
+ [], seed=key, maxval=len(available_ops), dtype=tf.int32
+ )
random_magnitude = float(magnitude)
with tf.name_scope('randaug_layer_{}'.format(layer_num)):
- for (i, op_name) in enumerate(available_ops):
+ for i, op_name in enumerate(available_ops):
key = tf.random.experimental.stateless_fold_in(key, i)
- prob = tf.random.stateless_uniform([],
- seed=key,
- minval=0.2,
- maxval=0.8,
- dtype=tf.float32)
- func, _, args = _parse_policy_info(op_name, prob, random_magnitude,
- replace_value, cutout_const=40,
- translate_const=100)
+ prob = tf.random.stateless_uniform(
+ [], seed=key, minval=0.2, maxval=0.8, dtype=tf.float32
+ )
+ func, _, args = _parse_policy_info(
+ op_name,
+ prob,
+ random_magnitude,
+ replace_value,
+ cutout_const=40,
+ translate_const=100,
+ )
image = tf.cond(
- tf.equal(i, op_to_select),
- lambda selected_func=func,
- selected_args=args: selected_func(image, *selected_args),
- lambda: image)
+ tf.equal(i, op_to_select),
+ lambda selected_func=func, selected_args=args: selected_func(
+ image, *selected_args
+ ),
+ lambda: image,
+ )
return image
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
index 4ec3937b8..c3035c212 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
@@ -9,70 +9,75 @@
import math
from typing import Dict, Iterator, Optional, Tuple
-from flax import jax_utils
-from flax import linen as nn
-from flax.core import pop
import jax
-from jax import lax
import jax.numpy as jnp
import optax
import tensorflow_datasets as tfds
+from flax import jax_utils
+from flax import linen as nn
+from flax.core import pop
+from jax import lax
-from algoperf import param_utils
+from algoperf import param_utils, spec
from algoperf import random_utils as prng
-from algoperf import spec
from algoperf.workloads.imagenet_resnet import imagenet_v2
-from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline
-from algoperf.workloads.imagenet_resnet.imagenet_jax import models
-from algoperf.workloads.imagenet_resnet.workload import \
- BaseImagenetResNetWorkload
+from algoperf.workloads.imagenet_resnet.imagenet_jax import (
+ input_pipeline,
+ models,
+)
+from algoperf.workloads.imagenet_resnet.workload import (
+ BaseImagenetResNetWorkload,
+)
class ImagenetResNetWorkload(BaseImagenetResNetWorkload):
-
def _build_dataset(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- use_mixup: bool = False,
- use_randaug: bool = False) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ use_mixup: bool = False,
+ use_randaug: bool = False,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
if split == 'test':
np_iter = imagenet_v2.get_imagenet_v2_iter(
- data_dir,
- global_batch_size,
- mean_rgb=self.train_mean,
- stddev_rgb=self.train_stddev,
- image_size=self.center_crop_size,
- resize_size=self.resize_size)
+ data_dir,
+ global_batch_size,
+ mean_rgb=self.train_mean,
+ stddev_rgb=self.train_stddev,
+ image_size=self.center_crop_size,
+ resize_size=self.resize_size,
+ )
return itertools.cycle(np_iter)
ds_builder = tfds.builder('imagenet2012:5.1.0', data_dir=data_dir)
train = split == 'train'
ds = input_pipeline.create_input_iter(
- split,
- ds_builder,
- data_rng,
- global_batch_size,
- self.train_mean,
- self.train_stddev,
- self.center_crop_size,
- self.resize_size,
- self.aspect_ratio_range,
- self.scale_ratio_range,
- train=train,
- cache=not train if cache is None else cache,
- repeat_final_dataset=repeat_final_dataset,
- use_mixup=use_mixup,
- mixup_alpha=0.2,
- use_randaug=use_randaug)
+ split,
+ ds_builder,
+ data_rng,
+ global_batch_size,
+ self.train_mean,
+ self.train_stddev,
+ self.center_crop_size,
+ self.resize_size,
+ self.aspect_ratio_range,
+ self.scale_ratio_range,
+ train=train,
+ cache=not train if cache is None else cache,
+ repeat_final_dataset=repeat_final_dataset,
+ use_mixup=use_mixup,
+ mixup_alpha=0.2,
+ use_randaug=use_randaug,
+ )
return ds
def sync_batch_stats(
- self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState:
+ self, model_state: spec.ModelAuxiliaryState
+ ) -> spec.ModelAuxiliaryState:
"""Sync the batch statistics across replicas."""
# An axis_name is passed to pmap which can then be used by pmean.
# In this case each device has its own version of the batch statistics and
@@ -83,13 +88,9 @@ def sync_batch_stats(
return new_model_state
def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """Dropout is unused."""
- del dropout_rate
- del aux_dropout_rate
+ self,
+ rng: spec.RandomState,
+ ) -> spec.ModelInitState:
model_cls = getattr(models, 'ResNet50')
if self.use_silu and self.use_gelu:
@@ -102,15 +103,17 @@ def init_model_fn(
act_fnc = nn.relu
model = model_cls(
- num_classes=self._num_classes,
- act=act_fnc,
- bn_init_scale=self.bn_init_scale,
- dtype=jnp.float32)
+ num_classes=self._num_classes,
+ act=act_fnc,
+ bn_init_scale=self.bn_init_scale,
+ dtype=jnp.float32,
+ )
self._model = model
input_shape = (1, 224, 224, 3)
- variables = jax.jit(model.init)({'params': rng},
- jnp.ones(input_shape, model.dtype))
- model_state, params = pop(variables, "params")
+ variables = jax.jit(model.init)(
+ {'params': rng}, jnp.ones(input_shape, model.dtype)
+ )
+ model_state, params = pop(variables, 'params')
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
model_state = jax_utils.replicate(model_state)
@@ -121,63 +124,70 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_0'
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, 0, 0, 0, 0),
- static_broadcasted_argnums=(0,))
- def _eval_model(self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, 0, 0, 0, 0),
+ static_broadcasted_argnums=(0,),
+ )
+ def _eval_model(
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ ) -> Dict[str, spec.Tensor]:
logits, _ = self.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.EVAL,
- rng=rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.EVAL,
+ rng=rng,
+ update_batch_norm=False,
+ )
weights = batch.get('weights')
return self._compute_metrics(logits, batch['targets'], weights)
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- use_running_average_bn: Optional[bool] = None
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ use_running_average_bn: Optional[bool] = None,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
if update_batch_norm:
logits, new_model_state = self._model.apply(
- variables,
- augmented_and_preprocessed_input_batch['inputs'],
- update_batch_norm=update_batch_norm,
- mutable=['batch_stats'],
- use_running_average_bn=use_running_average_bn)
+ variables,
+ augmented_and_preprocessed_input_batch['inputs'],
+ update_batch_norm=update_batch_norm,
+ mutable=['batch_stats'],
+ use_running_average_bn=use_running_average_bn,
+ )
return logits, new_model_state
else:
logits = self._model.apply(
- variables,
- augmented_and_preprocessed_input_batch['inputs'],
- update_batch_norm=update_batch_norm,
- mutable=False,
- use_running_average_bn=use_running_average_bn)
+ variables,
+ augmented_and_preprocessed_input_batch['inputs'],
+ update_batch_norm=update_batch_norm,
+ mutable=False,
+ use_running_average_bn=use_running_average_bn,
+ )
return logits, model_state
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -186,12 +196,14 @@ def loss_fn(
"""
if label_batch.shape[-1] != self._num_classes:
one_hot_labels = jax.nn.one_hot(
- label_batch, num_classes=self._num_classes)
+ label_batch, num_classes=self._num_classes
+ )
else:
one_hot_labels = label_batch
smoothed_labels = optax.smooth_labels(one_hot_labels, label_smoothing)
per_example_losses = -jnp.sum(
- smoothed_labels * jax.nn.log_softmax(logits_batch, axis=-1), axis=-1)
+ smoothed_labels * jax.nn.log_softmax(logits_batch, axis=-1), axis=-1
+ )
# mask_batch is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -200,36 +212,37 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': n_valid_examples,
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': n_valid_examples,
+ 'per_example': per_example_losses,
}
- def _compute_metrics(self,
- logits: spec.Tensor,
- labels: spec.Tensor,
- weights: spec.Tensor) -> Dict[str, spec.Tensor]:
+ def _compute_metrics(
+ self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor
+ ) -> Dict[str, spec.Tensor]:
if weights is None:
weights = jnp.ones(len(logits))
summed_loss = self.loss_fn(labels, logits, weights)['summed']
# not accuracy, but nr. of correct predictions
accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights)
metrics = {
- 'loss': summed_loss,
- 'accuracy': accuracy,
+ 'loss': summed_loss,
+ 'accuracy': accuracy,
}
metrics = lax.psum(metrics, axis_name='batch')
return metrics
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
del global_step
if model_state is not None:
# Sync batch statistics across replicas before evaluating.
@@ -239,13 +252,14 @@ def _eval_model_on_split(self,
# We already repeat the dataset indefinitely in tf.data.
if split not in self._eval_iters:
self._eval_iters[split] = self._build_input_queue(
- data_rng,
- split=split,
- global_batch_size=global_batch_size,
- data_dir=data_dir,
- cache=True,
- repeat_final_dataset=True,
- num_batches=num_batches)
+ data_rng,
+ split=split,
+ global_batch_size=global_batch_size,
+ data_dir=data_dir,
+ cache=True,
+ repeat_final_dataset=True,
+ num_batches=num_batches,
+ )
eval_metrics = {}
for bi in range(num_batches):
@@ -253,22 +267,21 @@ def _eval_model_on_split(self,
step_eval_rngs = prng.split(eval_rng, jax.local_device_count())
batch = next(self._eval_iters[split])
# We already average these metrics across devices inside _compute_metrics.
- synced_metrics = self._eval_model(params,
- batch,
- model_state,
- step_eval_rngs)
+ synced_metrics = self._eval_model(
+ params, batch, model_state, step_eval_rngs
+ )
for metric_name, metric_value in synced_metrics.items():
if metric_name not in eval_metrics:
eval_metrics[metric_name] = 0.0
eval_metrics[metric_name] += metric_value
- eval_metrics = jax.tree.map(lambda x: float(x[0] / num_examples),
- eval_metrics)
+ eval_metrics = jax.tree.map(
+ lambda x: float(x[0] / num_examples), eval_metrics
+ )
return eval_metrics
class ImagenetResNetSiLUWorkload(ImagenetResNetWorkload):
-
@property
def use_silu(self) -> bool:
return True
@@ -283,7 +296,6 @@ def test_target_value(self) -> float:
class ImagenetResNetGELUWorkload(ImagenetResNetWorkload):
-
@property
def use_gelu(self) -> bool:
return True
@@ -298,7 +310,6 @@ def test_target_value(self) -> float:
class ImagenetResNetLargeBNScaleWorkload(ImagenetResNetWorkload):
-
@property
def bn_init_scale(self) -> float:
return 8.0
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py
index aba9e671f..c980faa06 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py
@@ -8,51 +8,55 @@
from typing import Any, Callable, List, Optional, Type, Union
import torch
-from torch import nn
-from torch import Tensor
+from torch import Tensor, nn
from algoperf import spec
from algoperf.init_utils import pytorch_default_init
-def conv3x3(in_planes: int,
- out_planes: int,
- stride: int = 1,
- groups: int = 1,
- dilation: int = 1) -> nn.Conv2d:
+def conv3x3(
+ in_planes: int,
+ out_planes: int,
+ stride: int = 1,
+ groups: int = 1,
+ dilation: int = 1,
+) -> nn.Conv2d:
"""3x3 convolution with padding."""
return nn.Conv2d(
- in_planes,
- out_planes,
- kernel_size=3,
- stride=stride,
- padding=dilation,
- groups=groups,
- bias=False,
- dilation=dilation)
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation,
+ )
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
"""1x1 convolution."""
return nn.Conv2d(
- in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+ in_planes, out_planes, kernel_size=1, stride=stride, bias=False
+ )
class BasicBlock(nn.Module):
"""ResNet block."""
+
expansion: int = 1
def __init__(
- self,
- inplanes: int,
- planes: int,
- stride: int = 1,
- downsample: Optional[nn.Module] = None,
- groups: int = 1,
- base_width: int = 64,
- dilation: int = 1,
- norm_layer: Optional[Callable[..., nn.Module]] = None,
- act_fnc: nn.Module = nn.ReLU(inplace=True)
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ act_fnc: nn.Module = nn.ReLU(inplace=True),
) -> None:
super().__init__()
if norm_layer is None:
@@ -60,7 +64,7 @@ def __init__(
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
- raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ raise NotImplementedError('Dilation > 1 not supported in BasicBlock')
# Both self.conv1 and self.downsample layers downsample
# the input when stride != 1.
self.conv1 = conv3x3(inplanes, planes, stride)
@@ -92,24 +96,25 @@ def forward(self, x: spec.Tensor) -> spec.Tensor:
class Bottleneck(nn.Module):
"""Bottleneck ResNet block."""
+
expansion: int = 4
def __init__(
- self,
- inplanes: int,
- planes: int,
- stride: int = 1,
- downsample: Optional[nn.Module] = None,
- groups: int = 1,
- base_width: int = 64,
- dilation: int = 1,
- norm_layer: Optional[Callable[..., nn.Module]] = None,
- act_fnc: nn.Module = nn.ReLU(inplace=True)
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ act_fnc: nn.Module = nn.ReLU(inplace=True),
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
- width = int(planes * (base_width / 64.)) * groups
+ width = int(planes * (base_width / 64.0)) * groups
# Both self.conv2 and self.downsample layers downsample
# the input when stride != 1.
self.conv1 = conv1x1(inplanes, width)
@@ -146,18 +151,19 @@ def forward(self, x: Tensor) -> Tensor:
class ResNet(nn.Module):
-
- def __init__(self,
- block: Type[Union[BasicBlock, Bottleneck]],
- layers: List[int],
- num_classes: int = 1000,
- zero_init_residual: bool = True,
- groups: int = 1,
- width_per_group: int = 64,
- replace_stride_with_dilation: Optional[List[bool]] = None,
- norm_layer: Optional[Callable[..., nn.Module]] = None,
- act_fnc: nn.Module = nn.ReLU(inplace=True),
- bn_init_scale: float = 0.) -> None:
+ def __init__(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ num_classes: int = 1000,
+ zero_init_residual: bool = True,
+ groups: int = 1,
+ width_per_group: int = 64,
+ replace_stride_with_dilation: Optional[List[bool]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ act_fnc: nn.Module = nn.ReLU(inplace=True),
+ bn_init_scale: float = 0.0,
+ ) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
@@ -171,37 +177,42 @@ def __init__(self,
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError(
- 'replace_stride_with_dilation should be None '
- f'or a 3-element tuple, got {replace_stride_with_dilation}')
+ 'replace_stride_with_dilation should be None '
+ f'or a 3-element tuple, got {replace_stride_with_dilation}'
+ )
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(
- 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+ 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
+ )
self.bn1 = norm_layer(self.inplanes)
self.act_fnc = act_fnc
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, self.act_fnc, 64, layers[0])
self.layer2 = self._make_layer(
- block,
- self.act_fnc,
- 128,
- layers[1],
- stride=2,
- dilate=replace_stride_with_dilation[0])
+ block,
+ self.act_fnc,
+ 128,
+ layers[1],
+ stride=2,
+ dilate=replace_stride_with_dilation[0],
+ )
self.layer3 = self._make_layer(
- block,
- self.act_fnc,
- 256,
- layers[2],
- stride=2,
- dilate=replace_stride_with_dilation[1])
+ block,
+ self.act_fnc,
+ 256,
+ layers[2],
+ stride=2,
+ dilate=replace_stride_with_dilation[1],
+ )
self.layer4 = self._make_layer(
- block,
- self.act_fnc,
- 512,
- layers[3],
- stride=2,
- dilate=replace_stride_with_dilation[2])
+ block,
+ self.act_fnc,
+ 512,
+ layers[3],
+ stride=2,
+ dilate=replace_stride_with_dilation[2],
+ )
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
@@ -212,7 +223,7 @@ def __init__(self,
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
nn.init.normal_(self.fc.weight, std=1e-2)
- nn.init.constant_(self.fc.bias, 0.)
+ nn.init.constant_(self.fc.bias, 0.0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros,
@@ -226,13 +237,15 @@ def __init__(self,
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, bn_init_scale)
- def _make_layer(self,
- block: Type[Union[BasicBlock, Bottleneck]],
- act_fnc: nn.Module,
- planes: int,
- blocks: int,
- stride: int = 1,
- dilate: bool = False) -> nn.Sequential:
+ def _make_layer(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ act_fnc: nn.Module,
+ planes: int,
+ blocks: int,
+ stride: int = 1,
+ dilate: bool = False,
+ ) -> nn.Sequential:
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
@@ -241,34 +254,41 @@ def _make_layer(self,
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = torch.nn.Sequential(
- collections.OrderedDict([
- ("conv", conv1x1(self.inplanes, planes * block.expansion,
- stride)),
- ("bn", norm_layer(planes * block.expansion)),
- ]))
+ collections.OrderedDict(
+ [
+ ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)),
+ ('bn', norm_layer(planes * block.expansion)),
+ ]
+ )
+ )
layers = []
layers.append(
- block(self.inplanes,
- planes,
- stride,
- downsample,
- self.groups,
- self.base_width,
- previous_dilation,
- norm_layer,
- act_fnc))
+ block(
+ self.inplanes,
+ planes,
+ stride,
+ downsample,
+ self.groups,
+ self.base_width,
+ previous_dilation,
+ norm_layer,
+ act_fnc,
+ )
+ )
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
- block(
- self.inplanes,
- planes,
- groups=self.groups,
- base_width=self.base_width,
- dilation=self.dilation,
- norm_layer=norm_layer,
- act_fnc=act_fnc))
+ block(
+ self.inplanes,
+ planes,
+ groups=self.groups,
+ base_width=self.base_width,
+ dilation=self.dilation,
+ norm_layer=norm_layer,
+ act_fnc=act_fnc,
+ )
+ )
return nn.Sequential(*layers)
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py
index c7a98e77a..1c5c0d952 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py
@@ -11,8 +11,8 @@
import PIL
import torch
from torch import Tensor
-from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode
+from torchvision.transforms import functional as F
from algoperf import spec
@@ -24,8 +24,8 @@ def cutout(img: spec.Tensor, pad_size: int) -> spec.Tensor:
# Double the pad size to match Jax implementation.
pad_size = pad_size * 2
- x0 = int(max(0, x0 - pad_size / 2.))
- y0 = int(max(0, y0 - pad_size / 2.))
+ x0 = int(max(0, x0 - pad_size / 2.0))
+ y0 = int(max(0, y0 - pad_size / 2.0))
x1 = int(min(image_width, x0 + pad_size))
y1 = int(min(image_height, y0 + pad_size))
xy = (x0, y0, x1, y1)
@@ -36,7 +36,7 @@ def cutout(img: spec.Tensor, pad_size: int) -> spec.Tensor:
def solarize(img: spec.Tensor, threshold: float) -> spec.Tensor:
img = np.array(img)
- new_img = np.where(img < threshold, img, 255. - img)
+ new_img = np.where(img < threshold, img, 255.0 - img)
return PIL.Image.fromarray(new_img.astype(np.uint8))
@@ -49,54 +49,56 @@ def solarize_add(img: spec.Tensor, addition: int = 0) -> spec.Tensor:
return PIL.Image.fromarray(new_img)
-def _apply_op(img: spec.Tensor,
- op_name: str,
- magnitude: float,
- interpolation: InterpolationMode,
- fill: Optional[List[float]]) -> spec.Tensor:
+def _apply_op(
+ img: spec.Tensor,
+ op_name: str,
+ magnitude: float,
+ interpolation: InterpolationMode,
+ fill: Optional[List[float]],
+) -> spec.Tensor:
if op_name == 'ShearX':
# Magnitude should be arctan(magnitude).
img = F.affine(
- img,
- angle=0.0,
- translate=[0, 0],
- scale=1.0,
- shear=[math.degrees(math.atan(magnitude)), 0.0],
- interpolation=interpolation,
- fill=fill,
- center=[0, 0],
+ img,
+ angle=0.0,
+ translate=[0, 0],
+ scale=1.0,
+ shear=[math.degrees(math.atan(magnitude)), 0.0],
+ interpolation=interpolation,
+ fill=fill,
+ center=[0, 0],
)
elif op_name == 'ShearY':
# Magnitude should be arctan(magnitude).
img = F.affine(
- img,
- angle=0.0,
- translate=[0, 0],
- scale=1.0,
- shear=[0.0, math.degrees(math.atan(magnitude))],
- interpolation=interpolation,
- fill=fill,
- center=[0, 0],
+ img,
+ angle=0.0,
+ translate=[0, 0],
+ scale=1.0,
+ shear=[0.0, math.degrees(math.atan(magnitude))],
+ interpolation=interpolation,
+ fill=fill,
+ center=[0, 0],
)
elif op_name == 'TranslateX':
img = F.affine(
- img,
- angle=0.0,
- translate=[int(magnitude), 0],
- scale=1.0,
- interpolation=interpolation,
- shear=[0.0, 0.0],
- fill=fill,
+ img,
+ angle=0.0,
+ translate=[int(magnitude), 0],
+ scale=1.0,
+ interpolation=interpolation,
+ shear=[0.0, 0.0],
+ fill=fill,
)
elif op_name == 'TranslateY':
img = F.affine(
- img,
- angle=0.0,
- translate=[0, int(magnitude)],
- scale=1.0,
- interpolation=interpolation,
- shear=[0.0, 0.0],
- fill=fill,
+ img,
+ angle=0.0,
+ translate=[0, int(magnitude)],
+ scale=1.0,
+ interpolation=interpolation,
+ shear=[0.0, 0.0],
+ fill=fill,
)
elif op_name == 'Rotate':
img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
@@ -131,33 +133,32 @@ def _apply_op(img: spec.Tensor,
def ops_space() -> Dict[str, Tuple[spec.Tensor, bool]]:
return {
- # op_name: (magnitudes, signed)
- 'ShearX': (torch.tensor(0.3), True),
- 'ShearY': (torch.tensor(0.3), True),
- 'TranslateX': (torch.tensor(100), True),
- 'TranslateY': (torch.tensor(100), True),
- 'Rotate': (torch.tensor(30), True),
- 'Brightness': (torch.tensor(1.9), False),
- 'Color': (torch.tensor(1.9), False),
- 'Contrast': (torch.tensor(1.9), False),
- 'Sharpness': (torch.tensor(1.9), False),
- 'Posterize': (torch.tensor(4), False),
- 'Solarize': (torch.tensor(256), False),
- 'SolarizeAdd': (torch.tensor(110), False),
- 'AutoContrast': (torch.tensor(0.0), False),
- 'Equalize': (torch.tensor(0.0), False),
- 'Invert': (torch.tensor(0.0), False),
- 'Cutout': (torch.tensor(40.0), False),
+ # op_name: (magnitudes, signed)
+ 'ShearX': (torch.tensor(0.3), True),
+ 'ShearY': (torch.tensor(0.3), True),
+ 'TranslateX': (torch.tensor(100), True),
+ 'TranslateY': (torch.tensor(100), True),
+ 'Rotate': (torch.tensor(30), True),
+ 'Brightness': (torch.tensor(1.9), False),
+ 'Color': (torch.tensor(1.9), False),
+ 'Contrast': (torch.tensor(1.9), False),
+ 'Sharpness': (torch.tensor(1.9), False),
+ 'Posterize': (torch.tensor(4), False),
+ 'Solarize': (torch.tensor(256), False),
+ 'SolarizeAdd': (torch.tensor(110), False),
+ 'AutoContrast': (torch.tensor(0.0), False),
+ 'Equalize': (torch.tensor(0.0), False),
+ 'Invert': (torch.tensor(0.0), False),
+ 'Cutout': (torch.tensor(40.0), False),
}
class RandAugment(torch.nn.Module):
-
def __init__(
- self,
- num_ops: int = 2,
- interpolation: InterpolationMode = InterpolationMode.NEAREST,
- fill: Optional[List[float]] = None,
+ self,
+ num_ops: int = 2,
+ interpolation: InterpolationMode = InterpolationMode.NEAREST,
+ fill: Optional[List[float]] = None,
) -> None:
super().__init__()
self.num_ops = num_ops
@@ -183,5 +184,6 @@ def forward(self, img: spec.Tensor) -> spec.Tensor:
# With 50% prob turn the magnitude negative.
magnitude *= -1.0
img = _apply_op(
- img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
+ img, op_name, magnitude, interpolation=self.interpolation, fill=fill
+ )
return img
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
index ed29271f3..85a35dc45 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
@@ -16,22 +16,21 @@
from torchvision import transforms
from torchvision.datasets.folder import ImageFolder
-from algoperf import data_utils
-from algoperf import param_utils
-from algoperf import pytorch_utils
-from algoperf import spec
import algoperf.random_utils as prng
+from algoperf import data_utils, param_utils, pytorch_utils, spec
from algoperf.workloads.imagenet_resnet import imagenet_v2
from algoperf.workloads.imagenet_resnet.imagenet_pytorch import randaugment
from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import resnet50
-from algoperf.workloads.imagenet_resnet.workload import \
- BaseImagenetResNetWorkload
+from algoperf.workloads.imagenet_resnet.workload import (
+ BaseImagenetResNetWorkload,
+)
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
def imagenet_v2_to_torch(
- batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]:
+ batch: Dict[str, spec.Tensor],
+) -> Dict[str, spec.Tensor]:
# Slice off the part of the batch for this device and then transpose from
# [N, H, W, C] to [N, C, H, W]. Only transfer the inputs to GPU.
new_batch = {}
@@ -48,7 +47,6 @@ def imagenet_v2_to_torch(
class ImagenetResNetWorkload(BaseImagenetResNetWorkload):
-
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# Is set in submission_runner.py for workloads with PyTorch evaluation
@@ -59,7 +57,8 @@ def __init__(self, *args, **kwargs) -> None:
def eval_num_workers(self) -> int:
if self._eval_num_workers is None:
raise ValueError(
- 'eval_num_workers property must be set before workload is used.')
+ 'eval_num_workers property must be set before workload is used.'
+ )
return self._eval_num_workers
@eval_num_workers.setter
@@ -67,60 +66,68 @@ def eval_num_workers(self, eval_num_workers: int):
self._eval_num_workers = eval_num_workers
def _build_dataset(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- use_mixup: bool = False,
- use_randaug: bool = False) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ use_mixup: bool = False,
+ use_randaug: bool = False,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
del cache
del repeat_final_dataset
if split == 'test':
np_iter = imagenet_v2.get_imagenet_v2_iter(
- data_dir,
- global_batch_size,
- mean_rgb=self.train_mean,
- stddev_rgb=self.train_stddev,
- image_size=self.center_crop_size,
- resize_size=self.resize_size)
+ data_dir,
+ global_batch_size,
+ mean_rgb=self.train_mean,
+ stddev_rgb=self.train_stddev,
+ image_size=self.center_crop_size,
+ resize_size=self.resize_size,
+ )
return map(imagenet_v2_to_torch, itertools.cycle(np_iter))
is_train = split == 'train'
normalize = transforms.Normalize(
- mean=[i / 255. for i in self.train_mean],
- std=[i / 255. for i in self.train_stddev])
+ mean=[i / 255.0 for i in self.train_mean],
+ std=[i / 255.0 for i in self.train_stddev],
+ )
if is_train:
transform_config = [
- transforms.RandomResizedCrop(
- self.center_crop_size,
- scale=self.scale_ratio_range,
- ratio=self.aspect_ratio_range),
- transforms.RandomHorizontalFlip(),
+ transforms.RandomResizedCrop(
+ self.center_crop_size,
+ scale=self.scale_ratio_range,
+ ratio=self.aspect_ratio_range,
+ ),
+ transforms.RandomHorizontalFlip(),
]
if use_randaug:
transform_config.append(randaugment.RandAugment())
transform_config.extend([transforms.ToTensor(), normalize])
transform_config = transforms.Compose(transform_config)
else:
- transform_config = transforms.Compose([
+ transform_config = transforms.Compose(
+ [
transforms.Resize(self.resize_size),
transforms.CenterCrop(self.center_crop_size),
transforms.ToTensor(),
normalize,
- ])
+ ]
+ )
folder = 'train' if 'train' in split else 'val'
dataset = ImageFolder(
- os.path.join(data_dir, folder), transform=transform_config)
+ os.path.join(data_dir, folder), transform=transform_config
+ )
if split == 'eval_train':
indices = list(range(self.num_train_examples))
random.Random(int(data_rng[0])).shuffle(indices)
- dataset = torch.utils.data.Subset(dataset,
- indices[:self.num_eval_train_examples])
+ dataset = torch.utils.data.Subset(
+ dataset, indices[: self.num_eval_train_examples]
+ )
sampler = None
if USE_PYTORCH_DDP:
@@ -131,37 +138,34 @@ def _build_dataset(
if USE_PYTORCH_DDP:
if is_train:
sampler = torch.utils.data.distributed.DistributedSampler(
- dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True)
+ dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True
+ )
else:
sampler = data_utils.DistributedEvalSampler(
- dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False)
+ dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False
+ )
dataloader = torch.utils.data.DataLoader(
- dataset,
- batch_size=ds_iter_batch_size,
- shuffle=not USE_PYTORCH_DDP and is_train,
- sampler=sampler,
- num_workers=4 if is_train else self.eval_num_workers,
- pin_memory=True,
- drop_last=is_train,
- persistent_workers=is_train)
+ dataset,
+ batch_size=ds_iter_batch_size,
+ shuffle=not USE_PYTORCH_DDP and is_train,
+ sampler=sampler,
+ num_workers=4 if is_train else self.eval_num_workers,
+ pin_memory=True,
+ drop_last=is_train,
+ persistent_workers=is_train,
+ )
dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE)
dataloader = data_utils.cycle(
- dataloader,
- custom_sampler=USE_PYTORCH_DDP,
- use_mixup=use_mixup,
- mixup_alpha=0.2)
+ dataloader,
+ custom_sampler=USE_PYTORCH_DDP,
+ use_mixup=use_mixup,
+ mixup_alpha=0.2,
+ )
return dataloader
- def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """Dropout is unused."""
- del dropout_rate
- del aux_dropout_rate
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
if self.use_silu and self.use_gelu:
@@ -188,34 +192,40 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['fc.weight', 'fc.bias']
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = 0.0,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
+ del dropout_rate
model = params
if mode == spec.ForwardPassMode.EVAL:
if update_batch_norm:
raise ValueError(
- 'Batch norm statistics cannot be updated during evaluation.')
+ 'Batch norm statistics cannot be updated during evaluation.'
+ )
model.eval()
if mode == spec.ForwardPassMode.TRAIN:
model.train()
model.apply(
- functools.partial(
- pytorch_utils.update_batch_norm_fn,
- update_batch_norm=update_batch_norm))
+ functools.partial(
+ pytorch_utils.update_batch_norm_fn,
+ update_batch_norm=update_batch_norm,
+ )
+ )
contexts = {
- spec.ForwardPassMode.EVAL: torch.no_grad,
- spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+ spec.ForwardPassMode.EVAL: torch.no_grad,
+ spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
@@ -226,11 +236,12 @@ def model_fn(
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -238,10 +249,11 @@ def loss_fn(
(not synced across devices).
"""
per_example_losses = F.cross_entropy(
- logits_batch,
- label_batch,
- reduction='none',
- label_smoothing=label_smoothing)
+ logits_batch,
+ label_batch,
+ reduction='none',
+ label_smoothing=label_smoothing,
+ )
# `mask_batch` is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -250,15 +262,14 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
+ 'per_example': per_example_losses,
}
- def _compute_metrics(self,
- logits: spec.Tensor,
- labels: spec.Tensor,
- weights: spec.Tensor) -> Dict[str, spec.Tensor]:
+ def _compute_metrics(
+ self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor
+ ) -> Dict[str, spec.Tensor]:
"""Return the mean accuracy and loss as a dict."""
if weights is None:
weights = torch.ones(len(logits), device=DEVICE)
@@ -268,15 +279,17 @@ def _compute_metrics(self,
summed_loss = self.loss_fn(labels, logits, weights)['summed']
return {'accuracy': accuracy, 'loss': summed_loss}
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del global_step
data_rng, model_rng = prng.split(rng, 2)
@@ -284,31 +297,33 @@ def _eval_model_on_split(self,
is_test = split == 'test'
# These iterators repeat indefinitely.
self._eval_iters[split] = self._build_input_queue(
- data_rng,
- split=split,
- global_batch_size=global_batch_size,
- data_dir=data_dir,
- cache=is_test,
- repeat_final_dataset=is_test)
+ data_rng,
+ split=split,
+ global_batch_size=global_batch_size,
+ data_dir=data_dir,
+ cache=is_test,
+ repeat_final_dataset=is_test,
+ )
total_metrics = {
- 'accuracy': torch.tensor(0., device=DEVICE),
- 'loss': torch.tensor(0., device=DEVICE),
+ 'accuracy': torch.tensor(0.0, device=DEVICE),
+ 'loss': torch.tensor(0.0, device=DEVICE),
}
num_batches = int(math.ceil(num_examples / global_batch_size))
for _ in range(num_batches):
batch = next(self._eval_iters[split])
logits, _ = self.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.EVAL,
- model_rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.EVAL,
+ model_rng,
+ update_batch_norm=False,
+ )
weights = batch.get('weights')
batch_metrics = self._compute_metrics(logits, batch['targets'], weights)
total_metrics = {
- k: v + batch_metrics[k] for k, v in total_metrics.items()
+ k: v + batch_metrics[k] for k, v in total_metrics.items()
}
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
@@ -317,7 +332,6 @@ def _eval_model_on_split(self,
class ImagenetResNetSiLUWorkload(ImagenetResNetWorkload):
-
@property
def use_silu(self) -> bool:
return True
@@ -332,7 +346,6 @@ def test_target_value(self) -> float:
class ImagenetResNetGELUWorkload(ImagenetResNetWorkload):
-
@property
def use_gelu(self) -> bool:
return True
@@ -347,7 +360,6 @@ def test_target_value(self) -> float:
class ImagenetResNetLargeBNScaleWorkload(ImagenetResNetWorkload):
-
@property
def bn_init_scale(self) -> float:
return 8.0
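
Across the workload diffs above and below, dropout configuration moves out of `init_model_fn` and into a call-time `dropout_rate` argument on `model_fn`. For reference, a hedged stub of the shared signature; the ResNet workload above accepts `dropout_rate` only for API uniformity and immediately deletes it, since the model has no dropout layers. The class name here is hypothetical and used only for illustration.

```python
from typing import Dict, Tuple

from algoperf import spec


class AnyWorkload:  # hypothetical name, for illustration only
  def model_fn(
    self,
    params: spec.ParameterContainer,
    augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
    model_state: spec.ModelAuxiliaryState,
    mode: spec.ForwardPassMode,
    rng: spec.RandomState,
    update_batch_norm: bool,
    dropout_rate: float = 0.0,
  ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
    raise NotImplementedError
```
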
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_v2.py b/algoperf/workloads/imagenet_resnet/imagenet_v2.py
index 84d364586..7a8e38f02 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_v2.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_v2.py
@@ -8,37 +8,38 @@
import tensorflow_datasets as tfds
-from algoperf import data_utils
-from algoperf import spec
+from algoperf import data_utils, spec
from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline
-def get_imagenet_v2_iter(data_dir: str,
- global_batch_size: int,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float],
- image_size: int,
- resize_size: int) -> Iterator[Dict[str, spec.Tensor]]:
+def get_imagenet_v2_iter(
+ data_dir: str,
+ global_batch_size: int,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+ image_size: int,
+ resize_size: int,
+) -> Iterator[Dict[str, spec.Tensor]]:
"""Always caches and repeats indefinitely."""
ds = tfds.load(
- 'imagenet_v2/matched-frequency:3.0.0',
- split='test',
- data_dir=data_dir,
- decoders={
- 'image': tfds.decode.SkipDecoding(),
- })
+ 'imagenet_v2/matched-frequency:3.0.0',
+ split='test',
+ data_dir=data_dir,
+ decoders={
+ 'image': tfds.decode.SkipDecoding(),
+ },
+ )
def _decode_example(example: Dict[str, float]) -> Dict[str, float]:
- image = input_pipeline.preprocess_for_eval(example['image'],
- mean_rgb,
- stddev_rgb,
- image_size,
- resize_size)
+ image = input_pipeline.preprocess_for_eval(
+ example['image'], mean_rgb, stddev_rgb, image_size, resize_size
+ )
return {'inputs': image, 'targets': example['label']}
ds = ds.map(_decode_example, num_parallel_calls=16)
ds = ds.batch(global_batch_size)
shard_pad_fn = functools.partial(
- data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size)
+ data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size
+ )
it = map(shard_pad_fn, iter(ds))
return it
diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py
index 83fe97108..ef696e328 100644
--- a/algoperf/workloads/imagenet_resnet/workload.py
+++ b/algoperf/workloads/imagenet_resnet/workload.py
@@ -7,7 +7,6 @@
class BaseImagenetResNetWorkload(spec.Workload):
-
_num_classes: int = 1000
@property
@@ -15,8 +14,9 @@ def target_metric_name(self) -> str:
"""The name of the target metric (useful for scoring/processing code)."""
return 'accuracy'
- def has_reached_validation_target(self, eval_result: Dict[str,
- float]) -> bool:
+ def has_reached_validation_target(
+ self, eval_result: Dict[str, float]
+ ) -> bool:
return eval_result['validation/accuracy'] > self.validation_target_value
@property
@@ -58,8 +58,9 @@ def num_eval_train_examples(self) -> int:
# Round up from num_validation_examples (which is the default for
# num_eval_train_examples) to the next multiple of eval_batch_size, so that
# we don't have to extract the correctly sized subset of the training data.
- rounded_up_multiple = math.ceil(self.num_validation_examples /
- self.eval_batch_size)
+ rounded_up_multiple = math.ceil(
+ self.num_validation_examples / self.eval_batch_size
+ )
return rounded_up_multiple * self.eval_batch_size
@property
@@ -109,38 +110,37 @@ def eval_period_time_sec(self) -> int:
return 510 # 8.5 minutes.
def _build_dataset(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- use_mixup: bool = False,
- use_randaug: bool = False) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ use_mixup: bool = False,
+ use_randaug: bool = False,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
raise NotImplementedError
def _build_input_queue(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
del num_batches
if split == 'test':
if not cache:
raise ValueError('cache must be True for split=test.')
if not repeat_final_dataset:
raise ValueError('repeat_final_dataset must be True for split=test.')
- return self._build_dataset(data_rng,
- split,
- data_dir,
- global_batch_size,
- cache,
- repeat_final_dataset)
+ return self._build_dataset(
+ data_rng, split, data_dir, global_batch_size, cache, repeat_final_dataset
+ )
@property
def step_hint(self) -> int:
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
index 7ce3a0395..5e38acd8b 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
@@ -7,24 +7,29 @@
from typing import Optional, Sequence, Union
-from flax import linen as nn
import jax.numpy as jnp
+from flax import linen as nn
from algoperf import spec
+from algoperf.jax_utils import Dropout
+DROPOUT_RATE = 0.0
-def posemb_sincos_2d(h: int,
- w: int,
- width: int,
- temperature: int = 10_000.,
- dtype: jnp.dtype = jnp.float32) -> spec.Tensor:
+
+def posemb_sincos_2d(
+ h: int,
+ w: int,
+ width: int,
+ temperature: int = 10_000.0,
+ dtype: jnp.dtype = jnp.float32,
+) -> spec.Tensor:
"""Follows the MoCo v3 logic."""
- y, x = jnp.mgrid[:h, :w] #pylint: disable=unpacking-non-sequence
+ y, x = jnp.mgrid[:h, :w] # pylint: disable=unpacking-non-sequence
if width % 4 != 0:
raise ValueError('Width must be mult of 4 for sincos posemb.')
omega = jnp.arange(width // 4) / (width // 4 - 1)
- omega = 1. / (temperature**omega)
+ omega = 1.0 / (temperature**omega)
y = jnp.einsum('m,d->md', y.flatten(), omega)
x = jnp.einsum('m,d->md', x.flatten(), omega)
pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1)
@@ -33,16 +38,19 @@ def posemb_sincos_2d(h: int,
class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block."""
+
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
use_glu: bool = False
- dropout_rate: float = 0.0
+ dropout_rate: float = DROPOUT_RATE
@nn.compact
- def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
+ def __call__(
+ self, x: spec.Tensor, train: bool = True, dropout_rate=DROPOUT_RATE
+ ) -> spec.Tensor:
"""Applies Transformer MlpBlock module."""
inits = {
- 'kernel_init': nn.initializers.xavier_uniform(),
- 'bias_init': nn.initializers.normal(stddev=1e-6),
+ 'kernel_init': nn.initializers.xavier_uniform(),
+ 'bias_init': nn.initializers.normal(stddev=1e-6),
}
d = x.shape[2]
@@ -53,13 +61,14 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
y = nn.Dense(self.mlp_dim, **inits)(x)
x = x * y
- x = nn.Dropout(rate=self.dropout_rate)(x, train)
+ x = Dropout(dropout_rate)(x, train, rate=dropout_rate)
x = nn.Dense(d, **inits)(x)
return x
class Encoder1DBlock(nn.Module):
"""Single transformer encoder block (MHSA + MLP)."""
+
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
num_heads: int = 12
use_glu: bool = False
@@ -67,45 +76,46 @@ class Encoder1DBlock(nn.Module):
dropout_rate: float = 0.0
@nn.compact
- def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
+ def __call__(
+ self, x: spec.Tensor, train: bool = True, dropout_rate: float = DROPOUT_RATE
+ ) -> spec.Tensor:
if not self.use_post_layer_norm:
y = nn.LayerNorm(name='LayerNorm_0')(x)
y = nn.MultiHeadDotProductAttention(
- num_heads=self.num_heads,
- kernel_init=nn.initializers.xavier_uniform(),
- deterministic=train,
- name='MultiHeadDotProductAttention_1')(
- y)
- y = nn.Dropout(rate=self.dropout_rate)(y, train)
+ num_heads=self.num_heads,
+ kernel_init=nn.initializers.xavier_uniform(),
+ deterministic=train,
+ name='MultiHeadDotProductAttention_1',
+ )(y)
+ y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
x = x + y
y = nn.LayerNorm(name='LayerNorm_2')(x)
y = MlpBlock(
- mlp_dim=self.mlp_dim,
- use_glu=self.use_glu,
- dropout_rate=self.dropout_rate,
- name='MlpBlock_3')(y, train)
- y = nn.Dropout(rate=self.dropout_rate)(y, train)
+ mlp_dim=self.mlp_dim, use_glu=self.use_glu, name='MlpBlock_3'
+ )(y, train, dropout_rate=dropout_rate)
+ y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
x = x + y
else:
y = x
y = nn.MultiHeadDotProductAttention(
- num_heads=self.num_heads,
- kernel_init=nn.initializers.xavier_uniform(),
- deterministic=train,
- name='MultiHeadDotProductAttention_1')(
- y)
- y = nn.Dropout(rate=self.dropout_rate)(y, train)
+ num_heads=self.num_heads,
+ kernel_init=nn.initializers.xavier_uniform(),
+ deterministic=train,
+ name='MultiHeadDotProductAttention_1',
+ )(y)
+ y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
x = x + y
x = nn.LayerNorm(name='LayerNorm_0')(x)
y = x
y = MlpBlock(
- mlp_dim=self.mlp_dim,
- use_glu=self.use_glu,
- dropout_rate=self.dropout_rate,
- name='MlpBlock_3')(y, train)
- y = nn.Dropout(rate=self.dropout_rate)(y, train)
+ mlp_dim=self.mlp_dim,
+ use_glu=self.use_glu,
+ name='MlpBlock_3',
+ dropout_rate=dropout_rate,
+ )(y, train, dropout_rate=dropout_rate)
+ y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
x = x + y
x = nn.LayerNorm(name='LayerNorm_2')(x)
@@ -114,25 +124,26 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
class Encoder(nn.Module):
"""Transformer Model Encoder for sequence to sequence translation."""
+
depth: int
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
num_heads: int = 12
- dropout_rate: float = 0.0
use_glu: bool = False
use_post_layer_norm: bool = False
@nn.compact
- def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
+ def __call__(
+ self, x: spec.Tensor, train: bool = True, dropout_rate: float = DROPOUT_RATE
+ ) -> spec.Tensor:
# Input Encoder
for lyr in range(self.depth):
- block = Encoder1DBlock(
- name=f'encoderblock_{lyr}',
- mlp_dim=self.mlp_dim,
- num_heads=self.num_heads,
- use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm,
- dropout_rate=self.dropout_rate)
- x = block(x, train)
+ x = Encoder1DBlock(
+ name=f'encoderblock_{lyr}',
+ mlp_dim=self.mlp_dim,
+ num_heads=self.num_heads,
+ use_glu=self.use_glu,
+ use_post_layer_norm=self.use_post_layer_norm,
+ )(x, train=train, dropout_rate=dropout_rate)
if not self.use_post_layer_norm:
return nn.LayerNorm(name='encoder_layernorm')(x)
else:
@@ -141,24 +152,27 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
class MAPHead(nn.Module):
"""Multihead Attention Pooling."""
+
mlp_dim: Optional[int] = None # Defaults to 4x input dim
num_heads: int = 12
+ dropout_rate: float = 0.0
@nn.compact
- def __call__(self, x):
+ def __call__(self, x, dropout_rate=DROPOUT_RATE):
n, _, d = x.shape
- probe = self.param('probe',
- nn.initializers.xavier_uniform(), (1, 1, d),
- x.dtype)
+ probe = self.param(
+ 'probe', nn.initializers.xavier_uniform(), (1, 1, d), x.dtype
+ )
probe = jnp.tile(probe, [n, 1, 1])
x = nn.MultiHeadDotProductAttention(
- num_heads=self.num_heads,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(probe, x)
+ num_heads=self.num_heads,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(probe, x)
y = nn.LayerNorm()(x)
- x = x + MlpBlock(mlp_dim=self.mlp_dim)(y)
+ x = x + MlpBlock(mlp_dim=self.mlp_dim)(y, dropout_rate=dropout_rate)
return x[:, 0]
@@ -172,29 +186,30 @@ class ViT(nn.Module):
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
num_heads: int = 12
rep_size: Union[int, bool] = True
- dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0.
+ dropout_rate: float = DROPOUT_RATE
reinit: Optional[Sequence[str]] = None
head_zeroinit: bool = True
use_glu: bool = False
use_post_layer_norm: bool = False
use_map: bool = False
- def get_posemb(self,
- seqshape: tuple,
- width: int,
- dtype: jnp.dtype = jnp.float32) -> spec.Tensor:
+ def get_posemb(
+ self, seqshape: tuple, width: int, dtype: jnp.dtype = jnp.float32
+ ) -> spec.Tensor:
return posemb_sincos_2d(*seqshape, width, dtype=dtype)
@nn.compact
- def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor:
+ def __call__(
+ self, x: spec.Tensor, *, train: bool = False, dropout_rate=DROPOUT_RATE
+ ) -> spec.Tensor:
# Patch extraction
x = nn.Conv(
- self.width,
- self.patch_size,
- strides=self.patch_size,
- padding='VALID',
- name='conv_patch_extract')(
- x)
+ self.width,
+ self.patch_size,
+ strides=self.patch_size,
+ padding='VALID',
+ name='conv_patch_extract',
+ )(x)
n, h, w, c = x.shape
x = jnp.reshape(x, [n, h * w, c])
@@ -202,23 +217,23 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor:
# Add posemb before adding extra token.
x = x + self.get_posemb((h, w), c, x.dtype)
- dropout_rate = self.dropout_rate
- if dropout_rate is None:
- dropout_rate = 0.0
- x = nn.Dropout(rate=dropout_rate)(x, not train)
+ x = Dropout(dropout_rate)(x, not train, rate=dropout_rate)
x = Encoder(
- depth=self.depth,
- mlp_dim=self.mlp_dim,
- num_heads=self.num_heads,
- use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm,
- dropout_rate=dropout_rate,
- name='Transformer')(
- x, train=not train)
+ depth=self.depth,
+ mlp_dim=self.mlp_dim,
+ num_heads=self.num_heads,
+ use_glu=self.use_glu,
+ use_post_layer_norm=self.use_post_layer_norm,
+ name='Transformer',
+ )(x, train=not train, dropout_rate=dropout_rate)
if self.use_map:
- x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x)
+ x = MAPHead(
+ num_heads=self.num_heads,
+ mlp_dim=self.mlp_dim,
+ dropout_rate=dropout_rate,
+ )(x, dropout_rate=dropout_rate)
else:
x = jnp.mean(x, axis=1)
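
The JAX ViT above now threads `dropout_rate` through each module's `__call__` instead of baking it into module attributes. Below is a minimal Flax sketch of the same pattern, using the standard `flax.linen.Dropout` rather than the repository's `Dropout` wrapper from `algoperf.jax_utils` (which, as the diff shows, additionally accepts `rate` at apply time). Module and variable names are placeholders.

```python
import jax
import jax.numpy as jnp
from flax import linen as nn


class TinyBlock(nn.Module):
  features: int = 16

  @nn.compact
  def __call__(self, x, train: bool = True, dropout_rate: float = 0.0):
    x = nn.Dense(self.features)(x)
    # The rate arrives as a call-time argument, mimicking the pattern the
    # diff introduces; standard flax.linen.Dropout fixes it at construction.
    x = nn.Dropout(rate=dropout_rate, deterministic=not train)(x)
    return x


rng = jax.random.PRNGKey(0)
x = jnp.ones((2, 8))
model = TinyBlock()
variables = model.init({'params': rng, 'dropout': rng}, x)
y = model.apply(
  variables, x, train=True, dropout_rate=0.1, rngs={'dropout': rng}
)
```
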
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
index 35a6c46be..1637a2123 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -2,47 +2,44 @@
from typing import Dict, Optional, Tuple
+import jax
+import jax.numpy as jnp
from flax import jax_utils
from flax import linen as nn
from flax.core import pop
-import jax
-import jax.numpy as jnp
-from algoperf import param_utils
-from algoperf import spec
-from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \
- ImagenetResNetWorkload
+from algoperf import param_utils, spec
+from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import (
+ ImagenetResNetWorkload,
+)
from algoperf.workloads.imagenet_vit.imagenet_jax import models
-from algoperf.workloads.imagenet_vit.workload import BaseImagenetVitWorkload
-from algoperf.workloads.imagenet_vit.workload import decode_variant
+from algoperf.workloads.imagenet_vit.workload import (
+ BaseImagenetVitWorkload,
+ decode_variant,
+)
# Make sure we inherit from the ViT base workload first.
class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
-
- def initialized(self, key: spec.RandomState,
- model: nn.Module) -> spec.ModelInitState:
+ def initialized(
+ self, key: spec.RandomState, model: nn.Module
+ ) -> spec.ModelInitState:
input_shape = (1, 224, 224, 3)
- params_rng, dropout_rng = jax.random.split(key)
- variables = jax.jit(
- model.init)({'params': params_rng, 'dropout': dropout_rng},
- jnp.ones(input_shape))
- model_state, params = pop(variables, "params")
+ params_rng, _ = jax.random.split(key)
+ variables = jax.jit(model.init)(
+ {'params': params_rng}, jnp.ones(input_shape)
+ )
+ model_state, params = pop(variables, 'params')
return params, model_state
- def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- del aux_dropout_rate
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
self._model = models.ViT(
- dropout_rate=dropout_rate,
- num_classes=self._num_classes,
- use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm,
- use_map=self.use_map,
- **decode_variant('S/16'))
+ num_classes=self._num_classes,
+ use_glu=self.use_glu,
+ use_post_layer_norm=self.use_post_layer_norm,
+ use_map=self.use_map,
+ **decode_variant('S/16'),
+ )
params, model_state = self.initialized(rng, self._model)
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
@@ -54,44 +51,54 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'head'
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ use_running_average_bn: Optional[bool] = None,
+ dropout_rate: float = models.DROPOUT_RATE,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
+ del use_running_average_bn
train = mode == spec.ForwardPassMode.TRAIN
- logits = self._model.apply({'params': params},
- augmented_and_preprocessed_input_batch['inputs'],
- rngs={'dropout': rng},
- train=train)
+ logits = self._model.apply(
+ {'params': params},
+ augmented_and_preprocessed_input_batch['inputs'],
+ rngs={'dropout': rng},
+ train=train,
+ dropout_rate=dropout_rate,
+ )
return logits, None
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
model_state = None
- return super()._eval_model_on_split(split,
- num_examples,
- global_batch_size,
- params,
- model_state,
- rng,
- data_dir,
- global_step)
+ return super()._eval_model_on_split(
+ split,
+ num_examples,
+ global_batch_size,
+ params,
+ model_state,
+ rng,
+ data_dir,
+ global_step,
+ )
class ImagenetVitGluWorkload(ImagenetVitWorkload):
-
@property
def use_glu(self) -> bool:
return True
@@ -106,7 +113,6 @@ def test_target_value(self) -> float:
class ImagenetVitPostLNWorkload(ImagenetVitWorkload):
-
@property
def use_post_layer_norm(self) -> bool:
return True
@@ -121,7 +127,6 @@ def test_target_value(self) -> float:
class ImagenetVitMapWorkload(ImagenetVitWorkload):
-
@property
def use_map(self) -> bool:
return True
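
With the workload change above, dropout is no longer fixed when the model is initialized; it is supplied per forward pass, and the dropout PRNG still comes from the `rng` handed to `model_fn`. A hedged usage sketch, where `workload` is assumed to be an `ImagenetVitWorkload` instance and `batch` a dict with an `'inputs'` array:

```python
from algoperf import spec


def train_forward(workload, batch, rng, dropout_rate: float = 0.1):
  # No dropout arguments at init time anymore.
  params, model_state = workload.init_model_fn(rng)
  # The rate is chosen per call, e.g. by the submission's hyperparameters.
  logits, _ = workload.model_fn(
    params,
    batch,
    model_state,
    spec.ForwardPassMode.TRAIN,
    rng,
    update_batch_norm=False,
    dropout_rate=dropout_rate,
  )
  return logits
```
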
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
index fcf0992d3..fc2a3cd46 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
@@ -9,25 +9,29 @@
from typing import Any, Optional, Tuple, Union
import torch
-from torch import nn
import torch.nn.functional as F
+from torch import nn
-from algoperf import init_utils
-from algoperf import spec
+from algoperf import init_utils, spec
from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention
+DROPOUT_RATE = 0.0
+
-def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor:
+def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.0) -> spec.Tensor:
"""Follows the MoCo v3 logic."""
_, width, h, w = patches.shape
device = patches.device
- y, x = torch.meshgrid(torch.arange(h, device=device),
- torch.arange(w, device=device), indexing='ij')
+ y, x = torch.meshgrid(
+ torch.arange(h, device=device),
+ torch.arange(w, device=device),
+ indexing='ij',
+ )
if width % 4 != 0:
raise ValueError('Width must be mult of 4 for sincos posemb.')
omega = torch.arange(width // 4, device=device) / (width // 4 - 1)
- omega = 1. / (temperature**omega)
+ omega = 1.0 / (temperature**omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
@@ -38,21 +42,19 @@ class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block."""
def __init__(
- self,
- width: int,
- mlp_dim: Optional[int] = None, # Defaults to 4x input dim.
- use_glu: bool = False,
- dropout_rate: float = 0.0) -> None:
+ self,
+ width: int,
+ mlp_dim: Optional[int] = None, # Defaults to 4x input dim.
+ use_glu: bool = False,
+ ) -> None:
super().__init__()
self.width = width
self.mlp_dim = mlp_dim or 4 * width
self.use_glu = use_glu
- self.dropout_rate = dropout_rate
self.linear1 = nn.Linear(self.width, self.mlp_dim)
self.act_fnc = nn.GELU(approximate='tanh')
- self.dropout = nn.Dropout(self.dropout_rate)
if self.use_glu:
self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim)
@@ -70,7 +72,7 @@ def reset_parameters(self) -> None:
if module.bias is not None:
module.bias.data.normal_(std=1e-6)
- def forward(self, x: spec.Tensor) -> spec.Tensor:
+ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
x = self.linear1(x)
x = self.act_fnc(x)
@@ -78,7 +80,7 @@ def forward(self, x: spec.Tensor) -> spec.Tensor:
y = self.glu_linear(x)
x = x * y
- x = self.dropout(x)
+ x = F.dropout(x, dropout_rate, training=self.training)
x = self.linear2(x)
return x
@@ -86,17 +88,15 @@ def forward(self, x: spec.Tensor) -> spec.Tensor:
class SelfAttention(nn.Module):
"""Self-attention special case of multi-head dot-product attention."""
- def __init__(self,
- width: int,
- num_heads: int = 8,
- dropout_rate: float = 0.0) -> None:
+ def __init__(self, width: int, num_heads: int = 8) -> None:
super().__init__()
self.width = width
self.num_heads = num_heads
assert width % num_heads == 0, (
- 'Memory dimension must be divisible by number of heads.')
+ 'Memory dimension must be divisible by number of heads.'
+ )
self.head_dim = int(width / num_heads)
self.all_head_dim = self.num_heads * self.head_dim
@@ -104,7 +104,6 @@ def __init__(self,
self.query = nn.Linear(self.width, self.all_head_dim)
self.key = nn.Linear(self.width, self.all_head_dim)
self.value = nn.Linear(self.width, self.all_head_dim)
- self.dropout = nn.Dropout(dropout_rate)
self.out = nn.Linear(self.width, self.width)
self.reset_parameters()
@@ -113,14 +112,14 @@ def reset_parameters(self) -> None:
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight.data)
if module.bias is not None:
- nn.init.constant_(module.bias.data, 0.)
+ nn.init.constant_(module.bias.data, 0.0)
def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor:
new_x_shape = x.size()[:-1] + (self.num_heads, self.head_dim)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
- def forward(self, x: spec.Tensor) -> spec.Tensor:
+ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
mixed_query_layer = self.query(x)
key_layer = self.transpose_for_scores(self.key(x))
@@ -131,7 +130,7 @@ def forward(self, x: spec.Tensor) -> spec.Tensor:
attention_scores = attention_scores / math.sqrt(self.head_dim)
attention_probs = F.softmax(attention_scores, dim=-1)
- attention_probs = self.dropout(attention_probs)
+ attention_probs = F.dropout(attention_probs, dropout_rate, self.training)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
@@ -144,13 +143,14 @@ def forward(self, x: spec.Tensor) -> spec.Tensor:
class Encoder1DBlock(nn.Module):
"""Single transformer encoder block (MHSA + MLP)."""
- def __init__(self,
- width: int,
- mlp_dim: Optional[int] = None,
- num_heads: int = 12,
- use_glu: bool = False,
- use_post_layer_norm: bool = False,
- dropout_rate: float = 0.0) -> None:
+ def __init__(
+ self,
+ width: int,
+ mlp_dim: Optional[int] = None,
+ num_heads: int = 12,
+ use_glu: bool = False,
+ use_post_layer_norm: bool = False,
+ ) -> None:
super().__init__()
self.width = width
@@ -161,35 +161,32 @@ def __init__(self,
self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6)
self.self_attention1 = SelfAttention(self.width, self.num_heads)
- self.dropout = nn.Dropout(dropout_rate)
self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6)
self.mlp3 = MlpBlock(
- width=self.width,
- mlp_dim=self.mlp_dim,
- use_glu=self.use_glu,
- dropout_rate=dropout_rate)
+ width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu
+ )
- def forward(self, x: spec.Tensor) -> spec.Tensor:
+ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
if not self.use_post_layer_norm:
y = self.layer_norm0(x)
- y = self.self_attention1(y)
- y = self.dropout(y)
+ y = self.self_attention1(y, dropout_rate)
+ y = F.dropout(y, dropout_rate, training=self.training)
x = x + y
y = self.layer_norm2(x)
- y = self.mlp3(y)
- y = self.dropout(y)
+ y = self.mlp3(y, dropout_rate)
+ y = F.dropout(y, dropout_rate, training=self.training)
x = x + y
else:
y = x
- y = self.self_attention1(y)
- y = self.dropout(y)
+ y = self.self_attention1(y, dropout_rate)
+ y = F.dropout(y, dropout_rate, training=self.training)
x = x + y
x = self.layer_norm0(x)
y = x
- y = self.mlp3(y)
- y = self.dropout(y)
+ y = self.mlp3(y, dropout_rate)
+ y = F.dropout(y, dropout_rate, training=self.training)
x = x + y
x = self.layer_norm2(x)
return x
@@ -198,14 +195,15 @@ def forward(self, x: spec.Tensor) -> spec.Tensor:
class Encoder(nn.Module):
"""Transformer Model Encoder for sequence to sequence translation."""
- def __init__(self,
- depth: int,
- width: int,
- mlp_dim: Optional[int] = None,
- num_heads: int = 12,
- use_glu: bool = False,
- use_post_layer_norm: bool = False,
- dropout_rate: float = 0.0) -> None:
+ def __init__(
+ self,
+ depth: int,
+ width: int,
+ mlp_dim: Optional[int] = None,
+ num_heads: int = 12,
+ use_glu: bool = False,
+ use_post_layer_norm: bool = False,
+ ) -> None:
super().__init__()
self.depth = depth
@@ -215,24 +213,28 @@ def __init__(self,
self.use_glu = use_glu
self.use_post_layer_norm = use_post_layer_norm
- self.net = nn.ModuleList([
- Encoder1DBlock(self.width,
- self.mlp_dim,
- self.num_heads,
- self.use_glu,
- self.use_post_layer_norm,
- dropout_rate) for _ in range(depth)
- ])
+ self.net = nn.ModuleList(
+ [
+ Encoder1DBlock(
+ self.width,
+ self.mlp_dim,
+ self.num_heads,
+ self.use_glu,
+ self.use_post_layer_norm,
+ )
+ for _ in range(depth)
+ ]
+ )
if not self.use_post_layer_norm:
self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6)
else:
self.encoder_norm = None
- def forward(self, x: spec.Tensor) -> spec.Tensor:
+ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
# Input Encoder.
for block in self.net:
- x = block(x)
+ x = block(x, dropout_rate)
if not self.use_post_layer_norm:
return self.encoder_norm(x)
else:
@@ -242,10 +244,9 @@ def forward(self, x: spec.Tensor) -> spec.Tensor:
class MAPHead(nn.Module):
"""Multihead Attention Pooling."""
- def __init__(self,
- width: int,
- mlp_dim: Optional[int] = None,
- num_heads: int = 12):
+ def __init__(
+ self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12
+ ):
super().__init__()
self.width = width
self.mlp_dim = mlp_dim
@@ -255,17 +256,18 @@ def __init__(self,
nn.init.xavier_uniform_(self.probe.data)
self.mha = MultiheadAttention(
- self.width, num_heads=self.num_heads, self_attn=False, bias=True)
+ self.width, num_heads=self.num_heads, self_attn=False, bias=True
+ )
self.layer_norm = nn.LayerNorm(self.width, eps=1e-6)
self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim)
- def forward(self, x: spec.Tensor) -> spec.Tensor:
+ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
n, _, _ = x.shape
probe = torch.tile(self.probe, [n, 1, 1])
- x = self.mha(probe, x)[0]
+ x = self.mha(probe, x, dropout_rate=dropout_rate)[0]
y = self.layer_norm(x)
- x = x + self.mlp(y)
+ x = x + self.mlp(y, dropout_rate)
return x[:, 0]
@@ -277,23 +279,21 @@ class ViT(nn.Module):
channels: int = 3
def __init__(
- self,
- num_classes: int = 1000,
- patch_size: Tuple[int, int] = (16, 16),
- width: int = 768,
- depth: int = 12,
- mlp_dim: Optional[int] = None, # Defaults to 4x input dim.
- num_heads: int = 12,
- rep_size: Union[int, bool] = True,
- dropout_rate: Optional[float] = 0.0,
- head_zeroinit: bool = True,
- use_glu: bool = False,
- use_post_layer_norm: bool = False,
- use_map: bool = False,
- dtype: Any = torch.float32) -> None:
+ self,
+ num_classes: int = 1000,
+ patch_size: Tuple[int, int] = (16, 16),
+ width: int = 768,
+ depth: int = 12,
+ mlp_dim: Optional[int] = None, # Defaults to 4x input dim.
+ num_heads: int = 12,
+ rep_size: Union[int, bool] = True,
+ head_zeroinit: bool = True,
+ use_glu: bool = False,
+ use_post_layer_norm: bool = False,
+ use_map: bool = False,
+ dtype: Any = torch.float32,
+ ) -> None:
super().__init__()
- if dropout_rate is None:
- dropout_rate = 0.0
self.num_classes = num_classes
self.patch_size = patch_size
@@ -313,21 +313,21 @@ def __init__(
self.pre_logits = nn.Linear(self.width, rep_size)
self.conv_patch_extract = nn.Conv2d(
- self.channels,
- self.width,
- self.patch_size,
- stride=self.patch_size,
- padding='valid')
- self.dropout = nn.Dropout(p=dropout_rate)
+ self.channels,
+ self.width,
+ self.patch_size,
+ stride=self.patch_size,
+ padding='valid',
+ )
self.encoder = Encoder(
- depth=self.depth,
- width=self.width,
- mlp_dim=self.mlp_dim,
- num_heads=self.num_heads,
- use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm,
- dropout_rate=dropout_rate)
+ depth=self.depth,
+ width=self.width,
+ mlp_dim=self.mlp_dim,
+ num_heads=self.num_heads,
+ use_glu=self.use_glu,
+ use_post_layer_norm=self.use_post_layer_norm,
+ )
if self.num_classes:
self.head = nn.Linear(self.width, self.num_classes)
@@ -347,15 +347,17 @@ def reset_parameters(self) -> None:
if self.num_classes:
if self.head_zeroinit:
- nn.init.constant_(self.head.weight.data, 0.)
- nn.init.constant_(self.head.bias.data, 0.)
+ nn.init.constant_(self.head.weight.data, 0.0)
+ nn.init.constant_(self.head.bias.data, 0.0)
else:
init_utils.pytorch_default_init(self.head)
def get_posemb(self, x: spec.Tensor) -> spec.Tensor:
return posemb_sincos_2d(x).type(self.dtype)
- def forward(self, x: spec.Tensor) -> spec.Tensor:
+ def forward(
+ self, x: spec.Tensor, dropout_rate: float = DROPOUT_RATE
+ ) -> spec.Tensor:
# Patch extraction.
x = self.conv_patch_extract(x)
@@ -367,11 +369,11 @@ def forward(self, x: spec.Tensor) -> spec.Tensor:
x = torch.transpose(torch.reshape(x, (n, c, h * w)), 1, 2)
x = x + pes
- x = self.dropout(x)
- x = self.encoder(x)
+ x = F.dropout(x, dropout_rate, training=self.training)
+ x = self.encoder(x, dropout_rate)
if self.use_map:
- x = self.map(x)
+ x = self.map(x, dropout_rate)
else:
x = torch.mean(x, dim=1)
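
The PyTorch ViT takes the equivalent route: `nn.Dropout` modules are removed and `F.dropout` is called with a rate passed down through `forward`, so the rate can change between calls without rebuilding any layers. A minimal sketch of the pattern (illustration only, not the repository's model):

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinyBlock(nn.Module):
  def __init__(self, width: int = 16):
    super().__init__()
    self.linear = nn.Linear(width, width)

  def forward(self, x: torch.Tensor, dropout_rate: float = 0.0) -> torch.Tensor:
    x = self.linear(x)
    # Active only in train mode; a no-op under model.eval().
    return F.dropout(x, dropout_rate, training=self.training)


block = TinyBlock()
block.train()
y_train = block(torch.ones(2, 16), dropout_rate=0.1)
block.eval()
y_eval = block(torch.ones(2, 16), dropout_rate=0.1)  # dropout disabled
```
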
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
index 97bb38515..9c6faf70b 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
@@ -1,40 +1,35 @@
"""ImageNet ViT workload implemented in PyTorch."""
import contextlib
-from typing import Dict, Optional, Tuple
+from typing import Dict, Tuple
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
-from algoperf import param_utils
-from algoperf import pytorch_utils
-from algoperf import spec
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \
- ImagenetResNetWorkload
+from algoperf import param_utils, pytorch_utils, spec
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import (
+ ImagenetResNetWorkload,
+)
from algoperf.workloads.imagenet_vit.imagenet_pytorch import models
-from algoperf.workloads.imagenet_vit.workload import BaseImagenetVitWorkload
-from algoperf.workloads.imagenet_vit.workload import decode_variant
+from algoperf.workloads.imagenet_vit.workload import (
+ BaseImagenetVitWorkload,
+ decode_variant,
+)
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
# Make sure we inherit from the ViT base workload first.
class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
-
- def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- del aux_dropout_rate
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
model = models.ViT(
- dropout_rate=dropout_rate,
- num_classes=self._num_classes,
- use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm,
- use_map=self.use_map,
- **decode_variant('S/16'))
+ num_classes=self._num_classes,
+ use_glu=self.use_glu,
+ use_post_layer_norm=self.use_post_layer_norm,
+ use_map=self.use_map,
+ **decode_variant('S/16'),
+ )
self._param_shapes = param_utils.pytorch_param_shapes(model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
model.to(DEVICE)
@@ -49,13 +44,15 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['head.weight', 'head.bias']
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
@@ -69,18 +66,20 @@ def model_fn(
model.train()
contexts = {
- spec.ForwardPassMode.EVAL: torch.no_grad,
- spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+ spec.ForwardPassMode.EVAL: torch.no_grad,
+ spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
- logits_batch = model(augmented_and_preprocessed_input_batch['inputs'])
+ logits_batch = model(
+ augmented_and_preprocessed_input_batch['inputs'],
+ dropout_rate=dropout_rate,
+ )
return logits_batch, None
class ImagenetVitGluWorkload(ImagenetVitWorkload):
-
@property
def use_glu(self) -> bool:
return True
@@ -95,7 +94,6 @@ def test_target_value(self) -> float:
class ImagenetVitPostLNWorkload(ImagenetVitWorkload):
-
@property
def use_post_layer_norm(self) -> bool:
return True
@@ -110,7 +108,6 @@ def test_target_value(self) -> float:
class ImagenetVitMapWorkload(ImagenetVitWorkload):
-
@property
def use_map(self) -> bool:
return True
diff --git a/algoperf/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py
index f249ddee8..2a0070ba4 100644
--- a/algoperf/workloads/imagenet_vit/workload.py
+++ b/algoperf/workloads/imagenet_vit/workload.py
@@ -3,8 +3,9 @@
from typing import Dict, Iterator, Optional
from algoperf import spec
-from algoperf.workloads.imagenet_resnet.workload import \
- BaseImagenetResNetWorkload
+from algoperf.workloads.imagenet_resnet.workload import (
+ BaseImagenetResNetWorkload,
+)
def decode_variant(variant: str) -> Dict[str, int]:
@@ -12,46 +13,52 @@ def decode_variant(variant: str) -> Dict[str, int]:
v, patch = variant.split('/')
return {
- # Reference: Table 2 of https://arxiv.org/abs/2106.04560.
- 'width': {
- 'Ti': 192,
- 'S': 384,
- 'M': 512,
- 'B': 768,
- 'L': 1024,
- 'H': 1280,
- 'g': 1408,
- 'G': 1664,
- }[v],
- 'depth': {
- 'Ti': 12,
- 'S': 12,
- 'M': 12,
- 'B': 12,
- 'L': 24,
- 'H': 32,
- 'g': 40,
- 'G': 48,
- }[v],
- 'mlp_dim': {
- 'Ti': 768,
- 'S': 1536,
- 'M': 2048,
- 'B': 3072,
- 'L': 4096,
- 'H': 5120,
- 'g': 6144,
- 'G': 8192,
- }[v],
- 'num_heads': {
- 'Ti': 3, 'S': 6, 'M': 8, 'B': 12, 'L': 16, 'H': 16, 'g': 16, 'G': 16
- }[v],
- 'patch_size': (int(patch), int(patch)),
+ # Reference: Table 2 of https://arxiv.org/abs/2106.04560.
+ 'width': {
+ 'Ti': 192,
+ 'S': 384,
+ 'M': 512,
+ 'B': 768,
+ 'L': 1024,
+ 'H': 1280,
+ 'g': 1408,
+ 'G': 1664,
+ }[v],
+ 'depth': {
+ 'Ti': 12,
+ 'S': 12,
+ 'M': 12,
+ 'B': 12,
+ 'L': 24,
+ 'H': 32,
+ 'g': 40,
+ 'G': 48,
+ }[v],
+ 'mlp_dim': {
+ 'Ti': 768,
+ 'S': 1536,
+ 'M': 2048,
+ 'B': 3072,
+ 'L': 4096,
+ 'H': 5120,
+ 'g': 6144,
+ 'G': 8192,
+ }[v],
+ 'num_heads': {
+ 'Ti': 3,
+ 'S': 6,
+ 'M': 8,
+ 'B': 12,
+ 'L': 16,
+ 'H': 16,
+ 'g': 16,
+ 'G': 16,
+ }[v],
+ 'patch_size': (int(patch), int(patch)),
}
class BaseImagenetVitWorkload(BaseImagenetResNetWorkload):
-
@property
def validation_target_value(self) -> float:
return 1 - 0.22691 # 0.77309
@@ -88,25 +95,28 @@ def eval_period_time_sec(self) -> int:
return 7 * 60 # 7 mins.
def _build_dataset(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- use_mixup: bool = False,
- use_randaug: bool = False) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ use_mixup: bool = False,
+ use_randaug: bool = False,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
# We use mixup and Randaugment for ViT workloads.
use_mixup = use_randaug = split == 'train'
- return super()._build_dataset(data_rng,
- split,
- data_dir,
- global_batch_size,
- cache,
- repeat_final_dataset,
- use_mixup,
- use_randaug)
+ return super()._build_dataset(
+ data_rng,
+ split,
+ data_dir,
+ global_batch_size,
+ cache,
+ repeat_final_dataset,
+ use_mixup,
+ use_randaug,
+ )
@property
def step_hint(self) -> int:
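
For concreteness, this is the ViT-S/16 configuration that `decode_variant('S/16')` produces from the table above (a quick check, assuming the import path shown in the workload files):

```python
from algoperf.workloads.imagenet_vit.workload import decode_variant

config = decode_variant('S/16')
assert config == {
  'width': 384,
  'depth': 12,
  'mlp_dim': 1536,
  'num_heads': 6,
  'patch_size': (16, 16),
}
```
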
diff --git a/algoperf/workloads/librispeech_conformer/input_pipeline.py b/algoperf/workloads/librispeech_conformer/input_pipeline.py
index 1310e7b59..23ce8e3b7 100644
--- a/algoperf/workloads/librispeech_conformer/input_pipeline.py
+++ b/algoperf/workloads/librispeech_conformer/input_pipeline.py
@@ -4,13 +4,12 @@
import csv
-from absl import logging
import numpy as np
import torch
+from absl import logging
class LibriSpeechDataset(torch.utils.data.Dataset):
-
def __init__(self, split, data_dir):
super().__init__()
self.data_dir = data_dir
@@ -38,13 +37,14 @@ def __getitem__(self, index):
audio_paddings = np.zeros_like(audio, dtype=np.float32)
audio_paddings = np.pad(
- audio_paddings, (0, 320000 - audio.shape[0]), constant_values=1.0)
+ audio_paddings, (0, 320000 - audio.shape[0]), constant_values=1.0
+ )
audio = np.pad(audio, (0, 320000 - audio.shape[0]), constant_values=0.0)
target_paddings = np.zeros_like(targets, dtype=np.float32)
target_paddings = np.pad(
- target_paddings, (0, 256 - target_paddings.shape[0]),
- constant_values=1.0)
+ target_paddings, (0, 256 - target_paddings.shape[0]), constant_values=1.0
+ )
targets = np.pad(targets, (0, 256 - targets.shape[0]), constant_values=0)
audio = audio.astype(np.float32)
audio_paddings = audio_paddings.astype(np.float32)
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py
index 9f45434d9..531e68a45 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py
@@ -10,185 +10,186 @@
from typing import Any, Optional, Union
-from flax import linen as nn
-from flax import struct
import jax
import jax.numpy as jnp
import numpy as np
+from flax import linen as nn
+from flax import struct
# mel spectrum constants.
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
_MEL_HIGH_FREQUENCY_Q = 1127.0
LIBRISPEECH_MEAN_VECTOR = [
- -7.6047816276550293,
- -7.1206226348876953,
- -6.8864245414733887,
- -6.8705768585205078,
- -6.9667720794677734,
- -7.1084094047546387,
- -6.9528026580810547,
- -6.783994197845459,
- -6.6195521354675293,
- -6.4876265525817871,
- -6.4120659828186035,
- -6.394047737121582,
- -6.4244871139526367,
- -6.3993711471557617,
- -6.5158271789550781,
- -6.7137999534606934,
- -6.8476877212524414,
- -6.9885001182556152,
- -6.9221386909484863,
- -7.146148681640625,
- -7.2040400505065918,
- -7.0537552833557129,
- -7.3140382766723633,
- -7.1223249435424805,
- -7.30251407623291,
- -7.1212143898010254,
- -7.2425732612609863,
- -7.1730537414550781,
- -7.0979413986206055,
- -7.088747501373291,
- -6.9849910736083984,
- -6.8787732124328613,
- -6.7602753639221191,
- -6.6300945281982422,
- -6.5145769119262695,
- -6.4245057106018066,
- -6.356513500213623,
- -6.31787633895874,
- -6.2660770416259766,
- -6.2468328475952148,
- -6.2821526527404785,
- -6.1908388137817383,
- -6.2484354972839355,
- -6.1472640037536621,
- -6.0924725532531738,
- -6.0171003341674805,
- -5.9250402450561523,
- -5.8535833358764648,
- -5.8209109306335449,
- -5.8118929862976074,
- -5.80783748626709,
- -5.7714629173278809,
- -5.7453732490539551,
- -5.7705655097961426,
- -5.7765641212463379,
- -5.7831673622131348,
- -5.7954087257385254,
- -5.7994823455810547,
- -5.8023476600646973,
- -5.8047118186950684,
- -5.8168182373046875,
- -5.8844799995422363,
- -5.9727106094360352,
- -6.0444660186767578,
- -6.1284866333007812,
- -6.2257585525512695,
- -6.3157496452331543,
- -6.39061164855957,
- -6.4928598403930664,
- -6.5498456954956055,
- -6.6054320335388184,
- -6.6508378982543945,
- -6.66917610168457,
- -6.6726889610290527,
- -6.684234619140625,
- -6.6974577903747559,
- -6.75471830368042,
- -6.7949142456054688,
- -6.8634209632873535,
- -6.94186544418335
+ -7.6047816276550293,
+ -7.1206226348876953,
+ -6.8864245414733887,
+ -6.8705768585205078,
+ -6.9667720794677734,
+ -7.1084094047546387,
+ -6.9528026580810547,
+ -6.783994197845459,
+ -6.6195521354675293,
+ -6.4876265525817871,
+ -6.4120659828186035,
+ -6.394047737121582,
+ -6.4244871139526367,
+ -6.3993711471557617,
+ -6.5158271789550781,
+ -6.7137999534606934,
+ -6.8476877212524414,
+ -6.9885001182556152,
+ -6.9221386909484863,
+ -7.146148681640625,
+ -7.2040400505065918,
+ -7.0537552833557129,
+ -7.3140382766723633,
+ -7.1223249435424805,
+ -7.30251407623291,
+ -7.1212143898010254,
+ -7.2425732612609863,
+ -7.1730537414550781,
+ -7.0979413986206055,
+ -7.088747501373291,
+ -6.9849910736083984,
+ -6.8787732124328613,
+ -6.7602753639221191,
+ -6.6300945281982422,
+ -6.5145769119262695,
+ -6.4245057106018066,
+ -6.356513500213623,
+ -6.31787633895874,
+ -6.2660770416259766,
+ -6.2468328475952148,
+ -6.2821526527404785,
+ -6.1908388137817383,
+ -6.2484354972839355,
+ -6.1472640037536621,
+ -6.0924725532531738,
+ -6.0171003341674805,
+ -5.9250402450561523,
+ -5.8535833358764648,
+ -5.8209109306335449,
+ -5.8118929862976074,
+ -5.80783748626709,
+ -5.7714629173278809,
+ -5.7453732490539551,
+ -5.7705655097961426,
+ -5.7765641212463379,
+ -5.7831673622131348,
+ -5.7954087257385254,
+ -5.7994823455810547,
+ -5.8023476600646973,
+ -5.8047118186950684,
+ -5.8168182373046875,
+ -5.8844799995422363,
+ -5.9727106094360352,
+ -6.0444660186767578,
+ -6.1284866333007812,
+ -6.2257585525512695,
+ -6.3157496452331543,
+ -6.39061164855957,
+ -6.4928598403930664,
+ -6.5498456954956055,
+ -6.6054320335388184,
+ -6.6508378982543945,
+ -6.66917610168457,
+ -6.6726889610290527,
+ -6.684234619140625,
+ -6.6974577903747559,
+ -6.75471830368042,
+ -6.7949142456054688,
+ -6.8634209632873535,
+ -6.94186544418335,
]
LIBRISPEECH_STD_VECTOR = [
- 3.4353282451629639,
- 3.5962932109832764,
- 3.7012472152709961,
- 3.7369205951690674,
- 3.7535104751586914,
- 3.693629264831543,
- 3.6922497749328613,
- 3.7641522884368896,
- 3.8419716358184814,
- 3.8999848365783691,
- 3.9294240474700928,
- 3.9317409992218018,
- 3.9139585494995117,
- 3.9031598567962646,
- 3.8691999912261963,
- 3.8155081272125244,
- 3.7644970417022705,
- 3.7099106311798096,
- 3.6965086460113525,
- 3.6003766059875488,
- 3.5493226051330566,
- 3.5465121269226074,
- 3.45003604888916,
- 3.4712812900543213,
- 3.4084610939025879,
- 3.4408135414123535,
- 3.4104881286621094,
- 3.4217638969421387,
- 3.4312851428985596,
- 3.4199209213256836,
- 3.4305806159973145,
- 3.4382665157318115,
- 3.4580366611480713,
- 3.4817991256713867,
- 3.4958710670471191,
- 3.5036792755126953,
- 3.5047574043273926,
- 3.4988734722137451,
- 3.493056058883667,
- 3.4822943210601807,
- 3.459430456161499,
- 3.4612770080566406,
- 3.4559063911437988,
- 3.4755423069000244,
- 3.4971549510955811,
- 3.5326557159423828,
- 3.5705199241638184,
- 3.5920312404632568,
- 3.596907377243042,
- 3.5913500785827637,
- 3.5865931510925293,
- 3.5826809406280518,
- 3.5837743282318115,
- 3.5895791053771973,
- 3.5819313526153564,
- 3.5837869644165039,
- 3.5861184597015381,
- 3.5889589786529541,
- 3.592214822769165,
- 3.5939455032348633,
- 3.5856630802154541,
- 3.5884113311767578,
- 3.5921022891998291,
- 3.5870490074157715,
- 3.5806570053100586,
- 3.5731067657470703,
- 3.5617532730102539,
- 3.54980731010437,
- 3.5527374744415283,
- 3.5475366115570068,
- 3.5387849807739258,
- 3.5256178379058838,
- 3.5031836032867432,
- 3.4922726154327393,
- 3.4879646301269531,
- 3.4725594520568848,
- 3.4558389186859131,
- 3.4351828098297119,
- 3.4284293651580811,
- 3.4299170970916748
+ 3.4353282451629639,
+ 3.5962932109832764,
+ 3.7012472152709961,
+ 3.7369205951690674,
+ 3.7535104751586914,
+ 3.693629264831543,
+ 3.6922497749328613,
+ 3.7641522884368896,
+ 3.8419716358184814,
+ 3.8999848365783691,
+ 3.9294240474700928,
+ 3.9317409992218018,
+ 3.9139585494995117,
+ 3.9031598567962646,
+ 3.8691999912261963,
+ 3.8155081272125244,
+ 3.7644970417022705,
+ 3.7099106311798096,
+ 3.6965086460113525,
+ 3.6003766059875488,
+ 3.5493226051330566,
+ 3.5465121269226074,
+ 3.45003604888916,
+ 3.4712812900543213,
+ 3.4084610939025879,
+ 3.4408135414123535,
+ 3.4104881286621094,
+ 3.4217638969421387,
+ 3.4312851428985596,
+ 3.4199209213256836,
+ 3.4305806159973145,
+ 3.4382665157318115,
+ 3.4580366611480713,
+ 3.4817991256713867,
+ 3.4958710670471191,
+ 3.5036792755126953,
+ 3.5047574043273926,
+ 3.4988734722137451,
+ 3.493056058883667,
+ 3.4822943210601807,
+ 3.459430456161499,
+ 3.4612770080566406,
+ 3.4559063911437988,
+ 3.4755423069000244,
+ 3.4971549510955811,
+ 3.5326557159423828,
+ 3.5705199241638184,
+ 3.5920312404632568,
+ 3.596907377243042,
+ 3.5913500785827637,
+ 3.5865931510925293,
+ 3.5826809406280518,
+ 3.5837743282318115,
+ 3.5895791053771973,
+ 3.5819313526153564,
+ 3.5837869644165039,
+ 3.5861184597015381,
+ 3.5889589786529541,
+ 3.592214822769165,
+ 3.5939455032348633,
+ 3.5856630802154541,
+ 3.5884113311767578,
+ 3.5921022891998291,
+ 3.5870490074157715,
+ 3.5806570053100586,
+ 3.5731067657470703,
+ 3.5617532730102539,
+ 3.54980731010437,
+ 3.5527374744415283,
+ 3.5475366115570068,
+ 3.5387849807739258,
+ 3.5256178379058838,
+ 3.5031836032867432,
+ 3.4922726154327393,
+ 3.4879646301269531,
+ 3.4725594520568848,
+ 3.4558389186859131,
+ 3.4351828098297119,
+ 3.4284293651580811,
+ 3.4299170970916748,
]
@struct.dataclass
class LibrispeechPreprocessingConfig:
"""Config to hold all preprocessing options for librispeech dataset."""
+
sample_rate: float = 16000.0
frame_size_ms: float = 25.0
frame_step_ms: float = 10.0
@@ -208,8 +209,9 @@ class LibrispeechPreprocessingConfig:
def _hertz_to_mel(frequencies_hertz):
"""Convert hertz to mel."""
- return _MEL_HIGH_FREQUENCY_Q * jnp.log(1.0 + (frequencies_hertz /
- _MEL_BREAK_FREQUENCY_HERTZ))
+ return _MEL_HIGH_FREQUENCY_Q * jnp.log(
+ 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)
+ )
def _pad_end_length(num_timesteps, frame_step, frame_size):
@@ -221,11 +223,13 @@ def _pad_end_length(num_timesteps, frame_step, frame_size):
return padded_length - num_timesteps
-def frame(x,
- frame_length: int,
- frame_step: int,
- pad_end: bool = False,
- pad_value: Union[int, float] = 0.0):
+def frame(
+ x,
+ frame_length: int,
+ frame_step: int,
+ pad_end: bool = False,
+ pad_value: Union[int, float] = 0.0,
+):
"""Slides a window and extract values.
This function extracts `x[:, n:n+frame_length, :]` with sliding `n` with
@@ -251,24 +255,31 @@ def frame(x,
if pad_end:
num_extends = _pad_end_length(num_timesteps, frame_step, frame_length)
x = jnp.pad(
- x, ((0, 0), (0, num_extends), (0, 0)),
- 'constant',
- constant_values=pad_value)
+ x,
+ ((0, 0), (0, num_extends), (0, 0)),
+ 'constant',
+ constant_values=pad_value,
+ )
flat_y = jax.lax.conv_general_dilated_patches(
- x, (frame_length,), (frame_step,),
- 'VALID',
- dimension_numbers=('NTC', 'OIT', 'NTC'))
+ x,
+ (frame_length,),
+ (frame_step,),
+ 'VALID',
+ dimension_numbers=('NTC', 'OIT', 'NTC'),
+ )
ret = flat_y.reshape(flat_y.shape[:-1] + (num_channels, frame_length))
return ret.transpose((0, 1, 3, 2))
-def linear_to_mel_weight_matrix(num_mel_bins: int = 20,
- num_spectrogram_bins: int = 129,
- sample_rate: Union[int, float] = 8000,
- lower_edge_hertz: Union[int, float] = 125.0,
- upper_edge_hertz: Union[int, float] = 3800.0,
- dtype: Any = jnp.float32):
+def linear_to_mel_weight_matrix(
+ num_mel_bins: int = 20,
+ num_spectrogram_bins: int = 129,
+ sample_rate: Union[int, float] = 8000,
+ lower_edge_hertz: Union[int, float] = 125.0,
+ upper_edge_hertz: Union[int, float] = 3800.0,
+ dtype: Any = jnp.float32,
+):
r"""Jax-port of `tf.signal.linear_to_mel_weight_matrix`.
Args:
@@ -300,23 +311,29 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20,
if num_mel_bins <= 0:
raise ValueError('num_mel_bins must be positive. Got: %s' % num_mel_bins)
if lower_edge_hertz < 0.0:
- raise ValueError('lower_edge_hertz must be non-negative. Got: %s' %
- lower_edge_hertz)
+ raise ValueError(
+ 'lower_edge_hertz must be non-negative. Got: %s' % lower_edge_hertz
+ )
if lower_edge_hertz >= upper_edge_hertz:
- raise ValueError('lower_edge_hertz %.1f >= upper_edge_hertz %.1f' %
- (lower_edge_hertz, upper_edge_hertz))
+ raise ValueError(
+ 'lower_edge_hertz %.1f >= upper_edge_hertz %.1f'
+ % (lower_edge_hertz, upper_edge_hertz)
+ )
if sample_rate <= 0.0:
raise ValueError('sample_rate must be positive. Got: %s' % sample_rate)
if upper_edge_hertz > sample_rate / 2:
- raise ValueError('upper_edge_hertz must not be larger than the Nyquist '
- 'frequency (sample_rate / 2). Got %s for sample_rate: %s' %
- (upper_edge_hertz, sample_rate))
+ raise ValueError(
+ 'upper_edge_hertz must not be larger than the Nyquist '
+ 'frequency (sample_rate / 2). Got %s for sample_rate: %s'
+ % (upper_edge_hertz, sample_rate)
+ )
# HTK excludes the spectrogram DC bin.
bands_to_zero = 1
nyquist_hertz = sample_rate / 2.0
linear_frequencies = jnp.linspace(
- 0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype)[bands_to_zero:]
+ 0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype
+ )[bands_to_zero:]
spectrogram_bins_mel = _hertz_to_mel(linear_frequencies)[:, jnp.newaxis]
# Compute num_mel_bins triples of (lower_edge, center, upper_edge). The
@@ -324,10 +341,11 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20,
# Accordingly, we divide [lower_edge_hertz, upper_edge_hertz] into
# num_mel_bins + 2 pieces.
edges = jnp.linspace(
- _hertz_to_mel(lower_edge_hertz),
- _hertz_to_mel(upper_edge_hertz),
- num_mel_bins + 2,
- dtype=dtype)
+ _hertz_to_mel(lower_edge_hertz),
+ _hertz_to_mel(upper_edge_hertz),
+ num_mel_bins + 2,
+ dtype=dtype,
+ )
# Split the triples up and reshape them into [1, num_mel_bins] tensors.
lower_edge_mel = edges[:-2][jnp.newaxis, :]
@@ -337,9 +355,11 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20,
# Calculate lower and upper slopes for every spectrogram bin.
# Line segments are linear in the mel domain, not Hertz.
lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / (
- center_mel - lower_edge_mel)
+ center_mel - lower_edge_mel
+ )
upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / (
- upper_edge_mel - center_mel)
+ upper_edge_mel - center_mel
+ )
# Intersect the line segments with each other and zero.
mel_weights_matrix = jnp.maximum(0.0, jnp.minimum(lower_slopes, upper_slopes))
@@ -366,23 +386,26 @@ def _hanning_greco(win_support, frame_size, dtype):
"""
if frame_size < win_support:
raise ValueError(
- 'Provided frame_size = {} is lower than win_support = {}'.format(
- frame_size, win_support))
+ 'Provided frame_size = {} is lower than win_support = {}'.format(
+ frame_size, win_support
+ )
+ )
arg = jnp.pi * 2.0 / (win_support)
- hann = 0.5 - (0.5 * jnp.cos(arg *
- (jnp.arange(win_support, dtype=dtype) + 0.5)))
+ hann = 0.5 - (
+ 0.5 * jnp.cos(arg * (jnp.arange(win_support, dtype=dtype) + 0.5))
+ )
zero_size = frame_size - win_support
return jnp.pad(hann, [(0, zero_size)])
def _next_pow_of_two(x: Union[int, float]) -> int:
- return int(2**np.ceil(np.log2(x)))
+ return int(2 ** np.ceil(np.log2(x)))
class SpectrogramFrontend(nn.Module):
- """Layer to convert input audio signals from time domain to frequency domain.
- """
+ """Layer to convert input audio signals from time domain to frequency domain."""
+
config: LibrispeechPreprocessingConfig = None
input_scale_factor: float = 1.0
output_log: bool = False
@@ -390,8 +413,9 @@ class SpectrogramFrontend(nn.Module):
def setup(self) -> None:
p = self.config
self._frame_step = int(round(p.sample_rate * p.frame_step_ms / 1000.0))
- self._frame_size = int(round(
- p.sample_rate * p.frame_size_ms / 1000.0)) + 1 # +1 for the preemph
+ self._frame_size = (
+ int(round(p.sample_rate * p.frame_size_ms / 1000.0)) + 1
+ ) # +1 for the preemph
# TF-version has maximum of 512, but it's not always necessary
self.fft_size = _next_pow_of_two(self._frame_size)
@@ -421,32 +445,39 @@ def f(frame_size, dtype):
def _apply_preemphasis(self, framed_signal):
p = self.config
if p.preemph_htk_flavor:
- return jnp.concatenate([
- framed_signal[:, :, :1, :] * (1. - p.preemph),
- (framed_signal[:, :, 1:-1, :] -
- p.preemph * framed_signal[:, :, :-2, :])
- ],
- axis=2)
+ return jnp.concatenate(
+ [
+ framed_signal[:, :, :1, :] * (1.0 - p.preemph),
+ (
+ framed_signal[:, :, 1:-1, :]
+ - p.preemph * framed_signal[:, :, :-2, :]
+ ),
+ ],
+ axis=2,
+ )
else:
- return (framed_signal[:, :, 1:, :] -
- p.preemph * framed_signal[:, :, :-1, :])
+ return (
+ framed_signal[:, :, 1:, :] - p.preemph * framed_signal[:, :, :-1, :]
+ )
def fprop_paddings(self, input_paddings):
p = self.config
if p.pad_end:
- num_extends = _pad_end_length(input_paddings.shape[1],
- self._frame_step,
- self._frame_size)
+ num_extends = _pad_end_length(
+ input_paddings.shape[1], self._frame_step, self._frame_size
+ )
input_paddings = jnp.pad(
- input_paddings, ((0, 0), (0, num_extends)), constant_values=1.0)
+ input_paddings, ((0, 0), (0, num_extends)), constant_values=1.0
+ )
return jax.lax.reduce_window(
- input_paddings,
- init_value=1.0,
- computation=jax.lax.min,
- window_dimensions=[1, self._frame_size],
- window_strides=[1, self._frame_step],
- padding='valid')
+ input_paddings,
+ init_value=1.0,
+ computation=jax.lax.min,
+ window_dimensions=[1, self._frame_size],
+ window_strides=[1, self._frame_step],
+ padding='valid',
+ )
def next_prng_key(self, name='dropout'):
return self.make_rng(name)
@@ -469,7 +500,8 @@ def __call__(self, inputs, input_paddings):
pcm_audio_chunk = inputs.astype(jnp.float32) * self.input_scale_factor
framed_signal = frame(
- pcm_audio_chunk, self._frame_size, self._frame_step, pad_end=p.pad_end)
+ pcm_audio_chunk, self._frame_size, self._frame_step, pad_end=p.pad_end
+ )
if p.preemph != 0.0:
preemphasized = self._apply_preemphasis(framed_signal)
@@ -477,8 +509,10 @@ def __call__(self, inputs, input_paddings):
preemphasized = framed_signal[..., :-1, :]
if p.noise_scale > 0.0:
- noise_signal = jax.random.normal(self.next_prng_key(),
- preemphasized.shape) * p.noise_scale
+ noise_signal = (
+ jax.random.normal(self.next_prng_key(), preemphasized.shape)
+ * p.noise_scale
+ )
else:
noise_signal = jnp.zeros(preemphasized.shape)
@@ -501,8 +535,8 @@ def __call__(self, inputs, input_paddings):
class MelFilterbankFrontend(nn.Module):
- """Layer to compute log mel spectograms from input audio signals.
- """
+ """Layer to compute log mel spectograms from input audio signals."""
+
config: LibrispeechPreprocessingConfig = None
use_divide_stream: bool = True
per_bin_mean: Optional[float] = None
@@ -513,7 +547,8 @@ def setup(self):
input_scale_factor = 2**-15 if self.use_divide_stream else 1.0
self.stft = SpectrogramFrontend(
- p, input_scale_factor=input_scale_factor, output_log=False)
+ p, input_scale_factor=input_scale_factor, output_log=False
+ )
if self.per_bin_mean is None:
per_bin_mean = [0.0] * p.num_bins
@@ -526,9 +561,11 @@ def setup(self):
per_bin_stddev = self.per_bin_stddev
self._normalizer_mean = jnp.array(per_bin_mean)[
- jnp.newaxis, jnp.newaxis, :, jnp.newaxis]
+ jnp.newaxis, jnp.newaxis, :, jnp.newaxis
+ ]
self._normalizer_stddev = jnp.array(per_bin_stddev)[
- jnp.newaxis, jnp.newaxis, :, jnp.newaxis]
+ jnp.newaxis, jnp.newaxis, :, jnp.newaxis
+ ]
@nn.compact
def __call__(self, inputs, input_paddings):
@@ -537,18 +574,21 @@ def __call__(self, inputs, input_paddings):
spect, spect_paddings = self.stft(inputs, input_paddings)
mel_weights = linear_to_mel_weight_matrix(
- num_mel_bins=p.num_bins,
- num_spectrogram_bins=spect.shape[2],
- sample_rate=p.sample_rate,
- lower_edge_hertz=p.lower_edge_hertz,
- upper_edge_hertz=p.upper_edge_hertz)
+ num_mel_bins=p.num_bins,
+ num_spectrogram_bins=spect.shape[2],
+ sample_rate=p.sample_rate,
+ lower_edge_hertz=p.lower_edge_hertz,
+ upper_edge_hertz=p.upper_edge_hertz,
+ )
mel_spectrogram = jnp.einsum('fn,btfc->btnc', mel_weights, spect)
logmel_spectrogram = jnp.log(jnp.maximum(mel_spectrogram, p.output_floor))
normalized_logmel_spectrogram = (
- (logmel_spectrogram - self._normalizer_mean) / self._normalizer_stddev)
+ logmel_spectrogram - self._normalizer_mean
+ ) / self._normalizer_stddev
- normalized_logmel_spectrogram = jnp.squeeze(normalized_logmel_spectrogram,
- -1)
+ normalized_logmel_spectrogram = jnp.squeeze(
+ normalized_logmel_spectrogram, -1
+ )
return normalized_logmel_spectrogram, spect_paddings
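
The slope expressions reformatted above are the core of `linear_to_mel_weight_matrix`: each mel bin is a triangular filter built by intersecting a rising and a falling line segment in the mel domain and clipping at zero. Below is a minimal sketch of that construction, assuming an HTK-style hertz-to-mel conversion; the constants and helper name `hertz_to_mel` are illustrative and not taken from the benchmark code.

```python
import jax.numpy as jnp


def hertz_to_mel(freq_hz):
  # Assumed HTK-style conversion; constants are illustrative only.
  return 1127.0 * jnp.log1p(freq_hz / 700.0)


def toy_mel_weights(spectrogram_bins_hz, lower_hz, center_hz, upper_hz):
  # [num_spectrogram_bins, 1] broadcasts against [1, num_mel_bins].
  spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hz)[:, jnp.newaxis]
  lower, center, upper = (
      hertz_to_mel(jnp.asarray(x))[jnp.newaxis, :]
      for x in (lower_hz, center_hz, upper_hz)
  )
  lower_slopes = (spectrogram_bins_mel - lower) / (center - lower)
  upper_slopes = (upper - spectrogram_bins_mel) / (upper - center)
  # Intersect the rising and falling segments and clip at zero.
  return jnp.maximum(0.0, jnp.minimum(lower_slopes, upper_slopes))


# Toy usage: 257 FFT bins, two mel bins.
weights = toy_mel_weights(
    jnp.linspace(0.0, 8000.0, 257),
    [125.0, 250.0], [250.0, 500.0], [500.0, 1000.0],
)  # shape [257, 2]
```
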
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
index 593d463c3..9fc2e39ef 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
@@ -16,34 +16,32 @@
import math
from typing import Any, List, Optional
-from flax import linen as nn
-from flax import struct
import jax
import jax.numpy as jnp
import numpy as np
+from flax import linen as nn
+from flax import struct
-from algoperf.workloads.librispeech_conformer.librispeech_jax import \
- librispeech_preprocessor as preprocessor
-from algoperf.workloads.librispeech_conformer.librispeech_jax import \
- spectrum_augmenter
+from algoperf.jax_utils import Dropout
+from algoperf.workloads.librispeech_conformer.librispeech_jax import (
+ librispeech_preprocessor as preprocessor,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_jax import (
+ spectrum_augmenter,
+)
+
+DROPOUT_RATE = 0.1
@struct.dataclass
class ConformerConfig:
"""Global hyperparameters used to minimize obnoxious kwarg plumbing."""
+
vocab_size: int = 1024
dtype: Any = jnp.float32
encoder_dim: int = 512
num_attention_heads: int = 8
num_encoder_layers: int = 4
- attention_dropout_rate: float = 0.0
- # If None, defaults to 0.1.
- attention_residual_dropout_rate: Optional[float] = 0.1
- # If None, defaults to 0.0.
- conv_residual_dropout_rate: Optional[float] = 0.0
- feed_forward_dropout_rate: float = 0.0
- # If None, defaults to 0.1.
- feed_forward_residual_dropout_rate: Optional[float] = 0.1
convolution_kernel_size: int = 5
feed_forward_expansion_factor: int = 4
freq_mask_count: int = 2
@@ -53,8 +51,6 @@ class ConformerConfig:
time_mask_max_ratio: float = 0.05
time_masks_per_frame: float = 0.0
use_dynamic_time_mask_max_frames: bool = True
- # If None, defaults to 0.1.
- input_dropout_rate: Optional[float] = 0.1
batch_norm_momentum: float = 0.999
batch_norm_epsilon: float = 0.001
use_specaug: bool = True
@@ -73,6 +69,7 @@ class LayerNorm(nn.Module):
zeros, this differs from default flax implementation of multiplying by scale
and initializing to ones.
"""
+
dim: int = 0
epsilon: float = 1e-6
@@ -86,7 +83,7 @@ def __call__(self, inputs):
var = jnp.mean(jnp.square(inputs - mean), axis=[-1], keepdims=True)
normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon)
- normed_inputs *= (1 + self.scale)
+ normed_inputs *= 1 + self.scale
normed_inputs += self.bias
return normed_inputs
@@ -99,39 +96,41 @@ class Subsample(nn.Module):
encoder_dim: model dimension of conformer.
input_dropout_rate: dropout rate for inputs.
"""
+
encoder_dim: int = 0
- input_dropout_rate: float = 0.0
@nn.compact
- def __call__(self, inputs, input_paddings, train):
+ def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE):
output_paddings = input_paddings
outputs = jnp.expand_dims(inputs, axis=-1)
outputs, output_paddings = Conv2dSubsampling(
- input_channels=1, output_channels=self.encoder_dim)(
- outputs, output_paddings)
+ input_channels=1, output_channels=self.encoder_dim
+ )(outputs, output_paddings)
outputs, output_paddings = Conv2dSubsampling(
- input_channels=self.encoder_dim,
- output_channels=self.encoder_dim)(outputs, output_paddings)
+ input_channels=self.encoder_dim, output_channels=self.encoder_dim
+ )(outputs, output_paddings)
batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape
outputs = jnp.reshape(
- outputs, (batch_size, subsampled_lengths, subsampled_dims * channels))
+ outputs, (batch_size, subsampled_lengths, subsampled_dims * channels)
+ )
outputs = nn.Dense(
- self.encoder_dim,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- outputs)
+ self.encoder_dim,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(outputs)
outputs = outputs + AddPositionalEmbedding(embedding_dim=self.encoder_dim)(
- seq_length=outputs.shape[1])
+ seq_length=outputs.shape[1]
+ )
- outputs = nn.Dropout(
- rate=self.input_dropout_rate, deterministic=not train)(
- outputs)
+ outputs = Dropout(rate=dropout_rate, deterministic=not train)(
+ outputs, rate=dropout_rate
+ )
return outputs, output_paddings
@@ -143,6 +142,7 @@ class Conv2dSubsampling(nn.Module):
2) Also performs strided convolution over input_paddings to return the correct
paddings for downstream layers.
"""
+
input_channels: int = 0
output_channels: int = 0
filter_stride: List[int] = (2, 2)
@@ -150,24 +150,26 @@ class Conv2dSubsampling(nn.Module):
def setup(self):
self.filter_shape = (3, 3, self.input_channels, self.output_channels)
- self.kernel = self.param('kernel',
- nn.initializers.xavier_uniform(),
- self.filter_shape)
+ self.kernel = self.param(
+ 'kernel', nn.initializers.xavier_uniform(), self.filter_shape
+ )
self.bias = self.param(
- 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels)
+ 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels
+ )
@nn.compact
def __call__(self, inputs, paddings):
# Computing strided convolution to subsample inputs.
feature_group_count = inputs.shape[3] // self.filter_shape[2]
outputs = jax.lax.conv_general_dilated(
- lhs=inputs,
- rhs=self.kernel,
- window_strides=self.filter_stride,
- padding=self.padding,
- rhs_dilation=(1, 1),
- dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
- feature_group_count=feature_group_count)
+ lhs=inputs,
+ rhs=self.kernel,
+ window_strides=self.filter_stride,
+ padding=self.padding,
+ rhs_dilation=(1, 1),
+ dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
+ feature_group_count=feature_group_count,
+ )
outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,))
outputs = nn.relu(outputs)
@@ -178,64 +180,64 @@ def __call__(self, inputs, paddings):
pad_len = (input_length + stride - 1) // stride * stride - input_length
out_padding = jax.lax.conv_general_dilated(
- lhs=paddings[:, :, None],
- rhs=jnp.ones([1, 1, 1]),
- window_strides=self.filter_stride[:1],
- padding=[(0, pad_len)],
- dimension_numbers=('NHC', 'HIO', 'NHC'))
+ lhs=paddings[:, :, None],
+ rhs=jnp.ones([1, 1, 1]),
+ window_strides=self.filter_stride[:1],
+ padding=[(0, pad_len)],
+ dimension_numbers=('NHC', 'HIO', 'NHC'),
+ )
out_padding = jnp.squeeze(out_padding, axis=-1)
# Mask outputs by correct paddings to ensure padded elements in inputs map
# to padded value in outputs.
- outputs = outputs * \
- (1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1))
+ outputs = outputs * (
+ 1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1)
+ )
return outputs, out_padding
class FeedForwardModule(nn.Module):
- """Feedforward block of conformer layer.
- """
+ """Feedforward block of conformer layer."""
+
config: ConformerConfig
@nn.compact
- def __call__(self, inputs, padding_mask=None, train=False):
+ def __call__(
+ self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_RATE
+ ):
config = self.config
-
inputs = LayerNorm(dim=config.encoder_dim)(inputs)
inputs = nn.Dense(
- config.encoder_dim * config.feed_forward_expansion_factor,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- inputs)
+ config.encoder_dim * config.feed_forward_expansion_factor,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(inputs)
if config.activation_function_name == 'swish':
activation_fn = nn.swish
elif config.activation_function_name == 'gelu':
activation_fn = nn.gelu
else:
- raise ValueError('Only "swish" and "gelu" are supported '
- 'config.activation_function_name values, recieved '
- f'{config.activation_function_name}')
+ raise ValueError(
+ 'Only "swish" and "gelu" are supported '
+ 'config.activation_function_name values, received '
+ f'{config.activation_function_name}'
+ )
inputs = activation_fn(inputs)
- inputs = nn.Dropout(rate=config.feed_forward_dropout_rate)(
- inputs, deterministic=not train)
+ inputs = Dropout(rate=dropout_rate)(
+ inputs, deterministic=not train, rate=dropout_rate
+ )
inputs = inputs * padding_mask
inputs = nn.Dense(
- config.encoder_dim,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- inputs)
+ config.encoder_dim,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(inputs)
inputs = inputs * padding_mask
- if config.feed_forward_residual_dropout_rate is None:
- feed_forward_residual_dropout_rate = 0.1
- else:
- feed_forward_residual_dropout_rate = (
- config.feed_forward_residual_dropout_rate)
- inputs = nn.Dropout(rate=feed_forward_residual_dropout_rate)(
- inputs, deterministic=not train)
+ inputs = Dropout(rate=dropout_rate)(inputs, deterministic=not train)
return inputs
@@ -247,6 +249,7 @@ class AddPositionalEmbedding(nn.Module):
max_len: maximum possible length for the input
posemb_init: positional embedding initializer
"""
+
min_timescale: int = 1
max_timescale: int = 10_000
embedding_dim: int = 512
@@ -255,21 +258,23 @@ class AddPositionalEmbedding(nn.Module):
def __call__(self, seq_length):
position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :]
num_timescales = self.embedding_dim // 2
- log_timescale_increment = (
- math.log(float(self.max_timescale) / float(self.min_timescale)) /
- jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1))
+ log_timescale_increment = math.log(
+ float(self.max_timescale) / float(self.min_timescale)
+ ) / jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1)
inv_timescales = self.min_timescale * jnp.exp(
- jnp.arange(num_timescales, dtype=jnp.float32) *
- -log_timescale_increment)
+ jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment
+ )
scaled_time = (
- position[:, :, jnp.newaxis] *
- inv_timescales[jnp.newaxis, jnp.newaxis, :])
- signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)],
- axis=2).astype(jnp.float32)
+ position[:, :, jnp.newaxis] * inv_timescales[jnp.newaxis, jnp.newaxis, :]
+ )
+ signal = jnp.concatenate(
+ [jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=2
+ ).astype(jnp.float32)
# Force usage of `np` rather than `jnp` to compute static values at trace
# time.
- signal = jnp.pad(signal,
- [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]])
+ signal = jnp.pad(
+ signal, [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]]
+ )
return signal
@@ -277,6 +282,7 @@ def __call__(self, seq_length):
# https://github.com/tensorflow/lingvo/blob/7de4ca8fff3cb28c2ecb21bbd7b02a964ce727f7/lingvo/jax/layers/attentions.py#L201
class QueryScaler(nn.Module):
"""A layer to scale individual dims of the query attention matrix."""
+
dim: int = 0
def setup(self):
@@ -286,8 +292,10 @@ def setup(self):
def __call__(self, inputs):
inputs_shape = inputs.shape
if inputs_shape[-1] != self.dim:
- raise ValueError('QueryScaler expects inputs to have'
- ' same last dimension as scaling param.')
+ raise ValueError(
+ 'QueryScaler expects inputs to have'
+ ' same last dimension as scaling param.'
+ )
# 1.0/jax.nn.softplus(0.0) = 1.442695041. Hard code this number so that we
# can avoid unnecessary XLA op fusion mess on TPU.
@@ -302,18 +310,20 @@ def __call__(self, inputs):
# Modifying flax linen default dot product attention function to add
# query scaling, reference to original function here :
# https://github.com/google/flax/blob/a9af38085a7a49b571cf37d375060fd683e74972/flax/linen/attention.py#L121
-def dot_product_attention(query,
- key,
- value,
- bias=None,
- mask=None,
- broadcast_dropout=True,
- dropout_rng=None,
- dropout_rate=0.,
- deterministic=False,
- dtype=jnp.float32,
- precision=None,
- temperature=1.0):
+def dot_product_attention(
+ query,
+ key,
+ value,
+ bias=None,
+ mask=None,
+ broadcast_dropout=True,
+ dropout_rng=None,
+ dropout_rate=0.0,
+ deterministic=False,
+ dtype=jnp.float32,
+ precision=None,
+ temperature=1.0,
+):
"""Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on
@@ -352,29 +362,35 @@ def dot_product_attention(query,
"""
assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], (
- 'q, k, v batch dims must match.')
+ 'q, k, v batch dims must match.'
+ )
assert query.shape[-2] == key.shape[-2] == value.shape[-2], (
- 'q, k, v num_heads must match.')
+ 'q, k, v num_heads must match.'
+ )
assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
# compute attention weights
query = QueryScaler(dim=query.shape[-1])(query)
attn_weights = nn.attention.dot_product_attention_weights(
- query,
- key,
- bias,
- mask,
- broadcast_dropout,
- dropout_rng,
- dropout_rate,
- deterministic,
- dtype,
- precision)
+ query,
+ key,
+ bias,
+ mask,
+ broadcast_dropout,
+ dropout_rng,
+ dropout_rate,
+ deterministic,
+ dtype,
+ precision,
+ )
# return weighted sum over values for each query position
- return jnp.einsum(
- '...hqk,...khd->...qhd', attn_weights, value,
- precision=precision) * temperature
+ return (
+ jnp.einsum(
+ '...hqk,...khd->...qhd', attn_weights, value, precision=precision
+ )
+ * temperature
+ )
class MultiHeadedSelfAttention(nn.Module):
@@ -386,39 +402,39 @@ class MultiHeadedSelfAttention(nn.Module):
Note: this attention implementation uses a learned scale parameter to scale
query matrix before passing it to flax attention module.
"""
+
config: ConformerConfig = None
@nn.compact
- def __call__(self, inputs, paddings, train):
+ def __call__(self, inputs, paddings, train, dropout_rate=DROPOUT_RATE):
config = self.config
+
mask_paddings = 1 - paddings
attention_mask = nn.make_attention_mask(
- mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32)
+ mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32
+ )
inputs = LayerNorm(dim=config.encoder_dim)(inputs)
attention_fn = functools.partial(
- dot_product_attention, temperature=config.attention_temperature)
+ dot_product_attention, temperature=config.attention_temperature
+ )
result = nn.MultiHeadDotProductAttention(
- num_heads=config.num_attention_heads,
- qkv_features=config.encoder_dim,
- decode=False,
- dtype=config.dtype,
- kernel_init=nn.initializers.xavier_uniform(),
- bias_init=nn.initializers.zeros,
- use_bias=True,
- broadcast_dropout=False,
- attention_fn=attention_fn,
- dropout_rate=config.attention_dropout_rate,
- deterministic=not train)(
- inputs_q=inputs, mask=attention_mask)
-
- if config.attention_residual_dropout_rate is None:
- attention_residual_dropout_rate = 0.1
- else:
- attention_residual_dropout_rate = config.attention_residual_dropout_rate
- result = nn.Dropout(
- rate=attention_residual_dropout_rate, deterministic=not train)(
- result)
+ num_heads=config.num_attention_heads,
+ qkv_features=config.encoder_dim,
+ decode=False,
+ dtype=config.dtype,
+ kernel_init=nn.initializers.xavier_uniform(),
+ bias_init=nn.initializers.zeros,
+ use_bias=True,
+ broadcast_dropout=False,
+ attention_fn=attention_fn,
+ dropout_rate=dropout_rate,
+ deterministic=not train,
+ )(inputs_q=inputs, mask=attention_mask)
+
+ result = Dropout(rate=dropout_rate, deterministic=not train)(
+ result, rate=dropout_rate
+ )
return result
@@ -435,30 +451,27 @@ class BatchNorm(nn.Module):
and the corresponding defaults for momentum and epsilon have been copied over
from lingvo.
"""
+
config: ConformerConfig
def setup(self):
dim = self.config.encoder_dim
dtype = self.config.dtype
- self.ra_mean = self.variable('batch_stats',
- 'mean',
- lambda s: jnp.zeros(s, dtype),
- dim)
- self.ra_var = self.variable('batch_stats',
- 'var',
- lambda s: jnp.ones(s, dtype),
- dim)
+ self.ra_mean = self.variable(
+ 'batch_stats', 'mean', lambda s: jnp.zeros(s, dtype), dim
+ )
+ self.ra_var = self.variable(
+ 'batch_stats', 'var', lambda s: jnp.ones(s, dtype), dim
+ )
self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype)
self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)
@nn.compact
- def __call__(self,
- inputs,
- input_paddings,
- update_batch_norm,
- use_running_average_bn):
+ def __call__(
+ self, inputs, input_paddings, update_batch_norm, use_running_average_bn
+ ):
rank = inputs.ndim
reduce_over_dims = list(range(0, rank - 1))
@@ -475,23 +488,25 @@ def __call__(self,
mask = 1.0 - padding
sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True)
count_v = jnp.sum(
- jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True)
+ jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True
+ )
count_v = jnp.maximum(count_v, 1.0)
mean = sum_v / count_v
sum_vv = jnp.sum(
- (inputs - mean) * (inputs - mean) * mask,
- axis=reduce_over_dims,
- keepdims=True)
+ (inputs - mean) * (inputs - mean) * mask,
+ axis=reduce_over_dims,
+ keepdims=True,
+ )
var = sum_vv / count_v
if update_batch_norm:
- self.ra_mean.value = momentum * \
- self.ra_mean.value + (1 - momentum) * mean
- self.ra_var.value = momentum * \
- self.ra_var.value + (1 - momentum) * var
+ self.ra_mean.value = (
+ momentum * self.ra_mean.value + (1 - momentum) * mean
+ )
+ self.ra_var.value = momentum * self.ra_var.value + (1 - momentum) * var
inv = (1 + self.gamma) / jnp.sqrt(var + epsilon)
bn_output = (inputs - mean) * inv + self.beta
@@ -520,67 +535,68 @@ class ConvolutionBlock(nn.Module):
|
output
"""
+
config: ConformerConfig
@nn.compact
- def __call__(self,
- inputs,
- input_paddings,
- train,
- update_batch_norm,
- use_running_average_bn):
+ def __call__(
+ self,
+ inputs,
+ input_paddings,
+ train,
+ update_batch_norm,
+ use_running_average_bn,
+ dropout_rate=DROPOUT_RATE,
+ ):
config = self.config
inputs = LayerNorm(dim=config.encoder_dim)(inputs)
input_gated1 = nn.Dense(
- config.encoder_dim,
- kernel_init=nn.initializers.xavier_uniform(),
- use_bias=True)(
- inputs)
+ config.encoder_dim,
+ kernel_init=nn.initializers.xavier_uniform(),
+ use_bias=True,
+ )(inputs)
input_gated2 = nn.Dense(
- config.encoder_dim,
- kernel_init=nn.initializers.xavier_uniform(),
- use_bias=True)(
- inputs)
+ config.encoder_dim,
+ kernel_init=nn.initializers.xavier_uniform(),
+ use_bias=True,
+ )(inputs)
inputs = input_gated1 * jax.nn.sigmoid(input_gated2)
inputs = inputs * (1 - jnp.expand_dims(input_paddings, -1))
inputs = nn.Conv(
- features=config.encoder_dim,
- kernel_size=(config.convolution_kernel_size,),
- strides=(1,),
- padding='SAME',
- feature_group_count=config.encoder_dim,
- use_bias=False,
- kernel_init=nn.initializers.xavier_uniform())(
- inputs)
-
- inputs = BatchNorm(config)(inputs,
- input_paddings,
- update_batch_norm,
- use_running_average_bn)
+ features=config.encoder_dim,
+ kernel_size=(config.convolution_kernel_size,),
+ strides=(1,),
+ padding='SAME',
+ feature_group_count=config.encoder_dim,
+ use_bias=False,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(inputs)
+
+ inputs = BatchNorm(config)(
+ inputs, input_paddings, update_batch_norm, use_running_average_bn
+ )
if config.activation_function_name == 'swish':
activation_fn = nn.swish
elif config.activation_function_name == 'gelu':
activation_fn = nn.gelu
else:
- raise ValueError('Only "swish" and "gelu" are supported '
- 'config.activation_function_name values, recieved '
- f'{config.activation_function_name}')
+ raise ValueError(
+ 'Only "swish" and "gelu" are supported '
+ 'config.activation_function_name values, received '
+ f'{config.activation_function_name}'
+ )
inputs = activation_fn(inputs)
inputs = nn.Dense(
- config.encoder_dim, kernel_init=nn.initializers.xavier_uniform())(
- inputs)
+ config.encoder_dim, kernel_init=nn.initializers.xavier_uniform()
+ )(inputs)
- if config.conv_residual_dropout_rate is None:
- conv_residual_dropout_rate = 0.0
- else:
- conv_residual_dropout_rate = config.conv_residual_dropout_rate
- inputs = nn.Dropout(
- rate=conv_residual_dropout_rate, deterministic=not train)(
- inputs)
+ inputs = Dropout(rate=dropout_rate, deterministic=not train)(
+ inputs, rate=dropout_rate
+ )
return inputs
@@ -597,34 +613,42 @@ class ConformerBlock(nn.Module):
y = layer_norm(x)
"""
+
config: ConformerConfig
@nn.compact
- def __call__(self,
- inputs,
- input_paddings,
- train,
- update_batch_norm,
- use_running_average):
+ def __call__(
+ self,
+ inputs,
+ input_paddings,
+ train,
+ update_batch_norm,
+ use_running_average,
+ dropout_rate=DROPOUT_RATE,
+ ):
config = self.config
padding_mask = jnp.expand_dims(1 - input_paddings, -1)
inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
- inputs, padding_mask, train)
+ inputs, padding_mask, train, dropout_rate
+ )
inputs = inputs + MultiHeadedSelfAttention(config=self.config)(
- inputs, input_paddings, train)
-
- inputs = inputs + \
- ConvolutionBlock(config)(inputs,
- input_paddings,
- train,
- update_batch_norm,
- use_running_average
- )
+ inputs, input_paddings, train, dropout_rate=dropout_rate
+ )
+
+ inputs = inputs + ConvolutionBlock(config)(
+ inputs,
+ input_paddings,
+ train,
+ update_batch_norm,
+ use_running_average,
+ dropout_rate,
+ )
inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
- inputs, padding_mask, train)
+ inputs, padding_mask, train, dropout_rate
+ )
if config.use_post_layer_norm:
inputs = LayerNorm(dim=config.encoder_dim)(inputs)
@@ -639,26 +663,30 @@ class Conformer(nn.Module):
for each time step. The output is then fed into a CTC loss which eliminates
the need for alignment with targets.
"""
+
config: ConformerConfig
def setup(self):
self.specaug = spectrum_augmenter.SpecAug(
- freq_mask_count=self.config.freq_mask_count,
- freq_mask_max_bins=self.config.freq_mask_max_bins,
- time_mask_count=self.config.time_mask_count,
- time_mask_max_frames=self.config.time_mask_max_frames,
- time_mask_max_ratio=self.config.time_mask_max_ratio,
- time_masks_per_frame=self.config.time_masks_per_frame,
- use_dynamic_time_mask_max_frames=self.config
- .use_dynamic_time_mask_max_frames)
+ freq_mask_count=self.config.freq_mask_count,
+ freq_mask_max_bins=self.config.freq_mask_max_bins,
+ time_mask_count=self.config.time_mask_count,
+ time_mask_max_frames=self.config.time_mask_max_frames,
+ time_mask_max_ratio=self.config.time_mask_max_ratio,
+ time_masks_per_frame=self.config.time_masks_per_frame,
+ use_dynamic_time_mask_max_frames=self.config.use_dynamic_time_mask_max_frames,
+ )
@nn.compact
- def __call__(self,
- inputs,
- input_paddings,
- train,
- update_batch_norm: Optional[bool] = None,
- use_running_average_bn: Optional[bool] = None):
+ def __call__(
+ self,
+ inputs,
+ input_paddings,
+ train,
+ update_batch_norm: Optional[bool] = None,
+ use_running_average_bn: Optional[bool] = None,
+ dropout_rate: float = DROPOUT_RATE,
+ ):
config = self.config
outputs = inputs
@@ -675,38 +703,35 @@ def __call__(self,
outputs, output_paddings = preprocessor.MelFilterbankFrontend(
preprocessing_config,
per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
- per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)(
- outputs, output_paddings)
+ per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR,
+ )(outputs, output_paddings)
# Ablate random parts of input along temporal and frequency dimension
# following the specaug procedure in https://arxiv.org/abs/1904.08779.
if train and config.use_specaug:
outputs, output_paddings = self.specaug(outputs, output_paddings)
- # Subsample input by a factor of 4 by performing strided convolutions.
- if config.input_dropout_rate is None:
- input_dropout_rate = 0.1
- else:
- input_dropout_rate = config.input_dropout_rate
outputs, output_paddings = Subsample(
- encoder_dim=config.encoder_dim,
- input_dropout_rate=input_dropout_rate)(
- outputs, output_paddings, train)
+ encoder_dim=config.encoder_dim,
+ )(outputs, output_paddings, train, dropout_rate=dropout_rate)
# Run the conformer encoder layers.
for _ in range(config.num_encoder_layers):
- outputs = ConformerBlock(config)(outputs,
- output_paddings,
- train,
- update_batch_norm,
- use_running_average_bn)
+ outputs = ConformerBlock(config)(
+ outputs,
+ output_paddings,
+ train,
+ update_batch_norm,
+ use_running_average_bn,
+ dropout_rate,
+ )
outputs = LayerNorm(config.encoder_dim)(outputs)
# Run the decoder which in this case is a trivial projection layer.
outputs = nn.Dense(
- config.vocab_size,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- outputs)
+ config.vocab_size,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(outputs)
return outputs, output_paddings
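
The substantive change in `models.py` above is that the per-layer dropout config fields (`attention_residual_dropout_rate`, `feed_forward_residual_dropout_rate`, `conv_residual_dropout_rate`, `input_dropout_rate`) are removed and a single call-time `dropout_rate` (default `DROPOUT_RATE = 0.1`) is threaded from `Conformer.__call__` down through `Subsample`, `FeedForwardModule`, `MultiHeadedSelfAttention`, and `ConvolutionBlock`, using the `Dropout` module from `algoperf.jax_utils`. A minimal sketch of that pattern, with plain `flax.linen.Dropout` standing in for `algoperf.jax_utils.Dropout` (assumed to behave the same apart from accepting a call-time rate override):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

DROPOUT_RATE = 0.1


class ToyBlock(nn.Module):
  features: int = 16

  @nn.compact
  def __call__(self, x, train, dropout_rate=DROPOUT_RATE):
    x = nn.Dense(self.features)(x)
    # Dropout is disabled at eval time via `deterministic=not train`.
    return nn.Dropout(rate=dropout_rate, deterministic=not train)(x)


x = jnp.ones((2, 16))
variables = ToyBlock().init({'params': jax.random.PRNGKey(0)}, x, train=False)
# The submitter-facing rate is chosen at apply time, not baked into config.
y = ToyBlock().apply(
    variables, x, train=True, dropout_rate=0.2,
    rngs={'dropout': jax.random.PRNGKey(1)},
)
```
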
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py
index 2a6f73d4d..d9c1e301b 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py
@@ -17,6 +17,7 @@ class SpecAug(nn.Module):
This is an essential component in speech recognition models that helps achieve
better word error rates.
"""
+
freq_mask_count: int = 2
freq_mask_max_bins: int = 27
time_mask_count: int = 10
@@ -28,26 +29,30 @@ class SpecAug(nn.Module):
def next_prng_key(self, name='dropout'):
return self.make_rng(name)
- def _get_mask(self,
- batch_size,
- choose_range,
- mask_size,
- max_length=None,
- masks_per_frame=0.0,
- multiplicity=1,
- max_ratio=1.0):
+ def _get_mask(
+ self,
+ batch_size,
+ choose_range,
+ mask_size,
+ max_length=None,
+ masks_per_frame=0.0,
+ multiplicity=1,
+ max_ratio=1.0,
+ ):
# Sample lengths for multiple masks.
if max_length and max_length > 0:
max_length = jnp.tile(max_length, (batch_size,))
else:
max_length = choose_range * max_ratio
masked_portion = jax.random.uniform(
- key=self.next_prng_key(),
- shape=(batch_size, multiplicity),
- minval=0.0,
- maxval=1.0)
- masked_frame_size = jnp.einsum('b,bm->bm', max_length,
- masked_portion).astype(jnp.int32)
+ key=self.next_prng_key(),
+ shape=(batch_size, multiplicity),
+ minval=0.0,
+ maxval=1.0,
+ )
+ masked_frame_size = jnp.einsum(
+ 'b,bm->bm', max_length, masked_portion
+ ).astype(jnp.int32)
# Make sure the sampled length was smaller than max_ratio * length_bound.
# Note that sampling in this way was biased
# (shorter sequence may over-masked.)
@@ -57,7 +62,8 @@ def _get_mask(self,
# Choose starting point.
random_start = jax.random.uniform(
- key=self.next_prng_key(), shape=(batch_size, multiplicity), maxval=1.0)
+ key=self.next_prng_key(), shape=(batch_size, multiplicity), maxval=1.0
+ )
start_with_in_valid_range = random_start * (choose_range - length + 1)
start = start_with_in_valid_range.astype(jnp.int32)
@@ -78,11 +84,13 @@ def _get_mask(self,
# Sum masks with appropriate multiplicity.
if masks_per_frame > 0:
multiplicity_weights = jnp.tile(
- jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0),
- [batch_size, 1])
+ jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0),
+ [batch_size, 1],
+ )
multiplicity_tensor = masks_per_frame * choose_range
- multiplicity_weights = (multiplicity_weights <
- multiplicity_tensor).astype(jnp.int32)
+ multiplicity_weights = (
+ multiplicity_weights < multiplicity_tensor
+ ).astype(jnp.int32)
pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights)
else:
pre_mask = jnp.einsum('bmt->bt', pre_mask)
@@ -98,8 +106,9 @@ def _time_mask(self, inputs, length):
max_ratio = self.time_mask_max_ratio
# If maximum mask length is zero, do nothing.
- if ((time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames) or
- max_ratio <= 0.0):
+ if (
+ time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames
+ ) or max_ratio <= 0.0:
return inputs
if multiplicity == 0:
return inputs
@@ -111,13 +120,14 @@ def _time_mask(self, inputs, length):
time_mask_max_frames = None
# Create masks in time direction and apply.
block_arrays = self._get_mask(
- batch_size,
- choose_range=length,
- mask_size=time_length,
- max_length=time_mask_max_frames,
- masks_per_frame=self.time_masks_per_frame,
- multiplicity=multiplicity,
- max_ratio=max_ratio)
+ batch_size,
+ choose_range=length,
+ mask_size=time_length,
+ max_length=time_mask_max_frames,
+ masks_per_frame=self.time_masks_per_frame,
+ multiplicity=multiplicity,
+ max_ratio=max_ratio,
+ )
outputs = jnp.einsum('bxy,bx->bxy', inputs, block_arrays)
return outputs
@@ -136,13 +146,14 @@ def _frequency_mask(self, inputs):
choose_range = jnp.tile(num_freq, (batch_size,))
# Create masks in frequency direction and apply.
block_arrays = self._get_mask(
- batch_size,
- choose_range=choose_range,
- mask_size=num_freq,
- max_length=freq_mask_max_bins,
- masks_per_frame=0.0,
- multiplicity=multiplicity,
- max_ratio=1.0)
+ batch_size,
+ choose_range=choose_range,
+ mask_size=num_freq,
+ max_length=freq_mask_max_bins,
+ masks_per_frame=0.0,
+ multiplicity=multiplicity,
+ max_ratio=1.0,
+ )
outputs = jnp.einsum('bxy,by->bxy', inputs, block_arrays)
return outputs
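
The SpecAug changes above are formatting-only; the masking logic is unchanged: `_get_mask` produces binary `[batch, time]` or `[batch, freq]` arrays, and the einsum contractions broadcast them over the spectrogram. A toy illustration of those two contractions (shapes below are made up for brevity):

```python
import jax.numpy as jnp

features = jnp.ones((1, 4, 3))                        # [batch, time, freq]
time_mask = jnp.array([[1, 0, 1, 1]], jnp.float32)    # zero out frame 1
freq_mask = jnp.array([[1, 1, 0]], jnp.float32)       # zero out bin 2

# Time masking drops whole frames; frequency masking drops whole bins.
masked = jnp.einsum('bxy,bx->bxy', features, time_mask)
masked = jnp.einsum('bxy,by->bxy', masked, freq_mask)
```
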
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
index 39012a20d..819e57a69 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
@@ -2,31 +2,28 @@
import math
from typing import Dict, Iterator, Optional, Tuple
-from flax import jax_utils
-from flax.core import pop
import flax.linen as nn
import jax
-from jax import lax
import jax.numpy as jnp
import numpy as np
import optax
import torch
+from flax import jax_utils
+from flax.core import pop
+from jax import lax
-from algoperf import data_utils
-from algoperf import param_utils
-from algoperf import spec
-from algoperf.workloads.librispeech_conformer import metrics
-from algoperf.workloads.librispeech_conformer import workload
-from algoperf.workloads.librispeech_conformer.input_pipeline import \
- LibriSpeechDataset
+from algoperf import data_utils, param_utils, spec
+from algoperf.workloads.librispeech_conformer import metrics, workload
+from algoperf.workloads.librispeech_conformer.input_pipeline import (
+ LibriSpeechDataset,
+)
from algoperf.workloads.librispeech_conformer.librispeech_jax import models
class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload):
-
- def __init__(self,
- tokenizer_vocab_path: Optional[str] = None,
- use_specaug: bool = True) -> None:
+ def __init__(
+ self, tokenizer_vocab_path: Optional[str] = None, use_specaug: bool = True
+ ) -> None:
super().__init__()
self.metrics_bundle = metrics.get_metrics_bundle(tokenizer_vocab_path)
self.use_specaug = use_specaug
@@ -38,7 +35,8 @@ def __init__(self,
def eval_num_workers(self) -> int:
if self._eval_num_workers is None:
raise ValueError(
- 'eval_num_workers property must be set before workload is used.')
+ 'eval_num_workers property must be set before workload is used.'
+ )
return self._eval_num_workers
@eval_num_workers.setter
@@ -58,13 +56,12 @@ def attention_temperature(self) -> float:
return 1.0
def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ self,
+ rng: spec.RandomState,
+ ) -> spec.ModelInitState:
"""Conformer model init function.
- Here we use dropout_rate as *_residual_dropout_rate, and aux_dropout_rate as
+ Here we use dropout_rate as *_residual_dropout_rate, and as
input_dropout_rate.
"""
if self.use_gelu:
@@ -72,24 +69,22 @@ def init_model_fn(
else:
activation_function_name = 'swish'
model_config = models.ConformerConfig(
- attention_residual_dropout_rate=dropout_rate,
- feed_forward_residual_dropout_rate=dropout_rate,
- input_dropout_rate=aux_dropout_rate,
- use_specaug=self.use_specaug,
- attention_temperature=self.attention_temperature,
- use_post_layer_norm=self.use_post_layer_norm,
- activation_function_name=activation_function_name)
+ use_specaug=self.use_specaug,
+ attention_temperature=self.attention_temperature,
+ use_post_layer_norm=self.use_post_layer_norm,
+ activation_function_name=activation_function_name,
+ )
+
self._model = models.Conformer(model_config)
input_shape = [(320000,), (320000,)]
fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape]
model_init_fn = jax.jit(functools.partial(self._model.init, train=False))
- params_rng, dropout_rng = jax.random.split(rng, 2)
- variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng},
- *fake_input_batch)
+ params_rng, _ = jax.random.split(rng, 2)
+ variables = model_init_fn({'params': params_rng}, *fake_input_batch)
- model_state, params = pop(variables, "params")
+ model_state, params = pop(variables, 'params')
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
@@ -101,47 +96,52 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_0'
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- use_running_average_bn: Optional[bool] = None
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ use_running_average_bn: Optional[bool] = None,
+ dropout_rate: Optional[float] = models.DROPOUT_RATE,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
variables = {'params': params, **model_state}
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
is_train_mode = mode == spec.ForwardPassMode.TRAIN
if update_batch_norm or is_train_mode:
(logits, logit_paddings), new_model_state = self._model.apply(
- variables,
- inputs,
- input_paddings,
- train=True,
- rngs={'dropout' : rng},
- mutable=['batch_stats'],
- use_running_average_bn=use_running_average_bn)
+ variables,
+ inputs,
+ input_paddings,
+ train=True,
+ rngs={'dropout': rng},
+ mutable=['batch_stats'],
+ use_running_average_bn=use_running_average_bn,
+ dropout_rate=dropout_rate,
+ )
return (logits, logit_paddings), new_model_state
else:
logits, logit_paddings = self._model.apply(
- variables,
- inputs,
- input_paddings,
- train=False,
- mutable=False,
- use_running_average_bn=use_running_average_bn)
+ variables,
+ inputs,
+ input_paddings,
+ train=False,
+ mutable=False,
+ use_running_average_bn=use_running_average_bn,
+ )
return (logits, logit_paddings), model_state
def _build_input_queue(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
del data_rng
del cache
del repeat_final_dataset
@@ -160,38 +160,41 @@ def _build_input_queue(
ds = LibriSpeechDataset(split=split, data_dir=data_dir)
dataloader = data_utils.cycle(
- torch.utils.data.DataLoader(
- ds,
- batch_size=global_batch_size,
- shuffle=train,
- sampler=None,
- num_workers=4,
- prefetch_factor=10,
- pin_memory=False,
- drop_last=train,
- ))
+ torch.utils.data.DataLoader(
+ ds,
+ batch_size=global_batch_size,
+ shuffle=train,
+ sampler=None,
+ num_workers=4,
+ prefetch_factor=10,
+ pin_memory=False,
+ drop_last=train,
+ )
+ )
for batch in iter(dataloader):
inputs, input_paddings = batch['inputs']
targets, target_paddings = batch['targets']
numpy_batch = {
- 'inputs': (inputs.numpy(), input_paddings.numpy()),
- 'targets': (targets.numpy(), target_paddings.numpy()),
+ 'inputs': (inputs.numpy(), input_paddings.numpy()),
+ 'targets': (targets.numpy(), target_paddings.numpy()),
}
padded_batch = data_utils.shard_and_maybe_pad_np(
- numpy_batch, padding_value=1.0)
+ numpy_batch, padding_value=1.0
+ )
yield padded_batch
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding)
- logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding)
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding)
+ logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding)
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -202,10 +205,9 @@ def loss_fn(
logits, logit_paddings = logits_batch
targets, target_paddings = label_batch
logprobs = nn.log_softmax(logits)
- per_example_losses = self.ctc_loss(logprobs,
- logit_paddings,
- targets,
- target_paddings)
+ per_example_losses = self.ctc_loss(
+ logprobs, logit_paddings, targets, target_paddings
+ )
# mask_batch is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -215,23 +217,26 @@ def loss_fn(
n_valid_examples = jnp.maximum(mask_batch.sum(), 1)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': n_valid_examples,
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': n_valid_examples,
+ 'per_example': per_example_losses,
}
- def ctc_loss(self,
- logits: spec.Tensor,
- logit_paddings: spec.Tensor,
- labels: spec.Tensor,
- label_paddings: spec.Tensor,
- blank_id: int = 0) -> spec.Tensor:
+ def ctc_loss(
+ self,
+ logits: spec.Tensor,
+ logit_paddings: spec.Tensor,
+ labels: spec.Tensor,
+ label_paddings: spec.Tensor,
+ blank_id: int = 0,
+ ) -> spec.Tensor:
return optax.ctc_loss(
- logits=logits,
- logit_paddings=logit_paddings,
- labels=labels,
- label_paddings=label_paddings,
- blank_id=blank_id)
+ logits=logits,
+ logit_paddings=logit_paddings,
+ labels=labels,
+ label_paddings=label_paddings,
+ blank_id=blank_id,
+ )
# Adapted from lingvo's greedy decoding logic here:
# https://github.com/tensorflow/lingvo/blob/2ee26814c57b7dcead3f0382170f2f3da006f810/lingvo/jax/layers/ctc_objectives.py#L138.
@@ -242,21 +247,22 @@ def sequence_mask(self, lengths: spec.Tensor, maxlen: int) -> spec.Tensor:
c = jnp.less_equal(b, lengths[:, jnp.newaxis]).astype(lengths.dtype)
return c
- def collapse_and_remove_blanks(self,
- labels: spec.Tensor,
- seq_length: spec.Tensor,
- blank_id: int = 0) -> spec.Tensor:
+ def collapse_and_remove_blanks(
+ self, labels: spec.Tensor, seq_length: spec.Tensor, blank_id: int = 0
+ ) -> spec.Tensor:
b, t = labels.shape
# Zap out blank.
blank_mask = 1 - jnp.equal(labels, blank_id)
labels = (labels * blank_mask).astype(labels.dtype)
# Mask labels that don't equal previous label.
- label_mask = jnp.concatenate([
+ label_mask = jnp.concatenate(
+ [
jnp.ones_like(labels[:, :1], dtype=jnp.int32),
jnp.not_equal(labels[:, 1:], labels[:, :-1]),
- ],
- axis=1)
+ ],
+ axis=1,
+ )
# Filter labels that aren't in the original sequence.
maxlen = labels.shape[1]
@@ -292,12 +298,14 @@ def collapse_and_remove_blanks(self,
# Reshape back to square batch.
batch_size = labels.shape[0]
new_shape = [batch_size, new_maxlen]
- return (jnp.reshape(flat, new_shape).astype(labels.dtype),
- new_seq_len.astype(seq_length.dtype))
+ return (
+ jnp.reshape(flat, new_shape).astype(labels.dtype),
+ new_seq_len.astype(seq_length.dtype),
+ )
def greedy_decode(
- self, logits: spec.Tensor,
- logit_paddings: spec.Tensor) -> Tuple[spec.Tensor, spec.Tensor]:
+ self, logits: spec.Tensor, logit_paddings: spec.Tensor
+ ) -> Tuple[spec.Tensor, spec.Tensor]:
per_frame_max = jnp.argmax(logits, axis=-1)
seqlen = jnp.sum(1.0 - logit_paddings, axis=-1)
hyp, _ = self.collapse_and_remove_blanks(per_frame_max, seqlen, blank_id=0)
@@ -305,45 +313,51 @@ def greedy_decode(
return hyp, hyp_paddings
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, 0, 0, 0, None),
- static_broadcasted_argnums=(0,))
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, 0, 0, 0, None),
+ static_broadcasted_argnums=(0,),
+ )
def eval_step_pmapped(
- self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
(logits, logit_paddings), _ = self.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.EVAL,
- rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.EVAL,
+ rng,
+ update_batch_norm=False,
+ )
decoded, decoded_paddings = self.greedy_decode(logits, logit_paddings)
loss = self.loss_fn(batch['targets'], (logits, logit_paddings))
targets, target_paddings = batch['targets']
return self.metrics_bundle.gather_from_model_output(
- loss_dict=loss,
- decoded=decoded,
- decoded_paddings=decoded_paddings,
- targets=targets,
- target_paddings=target_paddings,
- axis_name='batch')
-
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ loss_dict=loss,
+ decoded=decoded,
+ decoded_paddings=decoded_paddings,
+ targets=targets,
+ target_paddings=target_paddings,
+ axis_name='batch',
+ )
+
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del global_step
if model_state is not None and len(model_state) > 0:
@@ -353,15 +367,15 @@ def _eval_model_on_split(self,
num_batches = int(math.ceil(num_examples / global_batch_size))
if split not in self._eval_iters:
self._eval_iters[split] = self._build_input_queue(
- rng, split, data_dir, global_batch_size, num_batches=num_batches)
+ rng, split, data_dir, global_batch_size, num_batches=num_batches
+ )
metrics_report = None
for _ in range(num_batches):
eval_batch = next(self._eval_iters[split])
- computed_metrics = self.eval_step_pmapped(params,
- eval_batch,
- model_state,
- rng).unreplicate()
+ computed_metrics = self.eval_step_pmapped(
+ params, eval_batch, model_state, rng
+ ).unreplicate()
if metrics_report is None:
metrics_report = computed_metrics
@@ -374,7 +388,8 @@ def _eval_model_on_split(self,
return computed_metrics
def sync_batch_stats(
- self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState:
+ self, model_state: spec.ModelAuxiliaryState
+ ) -> spec.ModelAuxiliaryState:
# An axis_name is passed to pmap which can then be used by pmean.
# In this case each device has its own version of the batch statistics and
# we average them.
@@ -385,8 +400,8 @@ def sync_batch_stats(
class LibriSpeechConformerAttentionTemperatureWorkload(
- LibriSpeechConformerWorkload):
-
+ LibriSpeechConformerWorkload
+):
@property
def attention_temperature(self) -> float:
return 1.6
@@ -401,7 +416,6 @@ def test_target_value(self) -> float:
class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload):
-
@property
def use_post_layer_norm(self) -> bool:
return False
@@ -416,7 +430,6 @@ def test_target_value(self) -> float:
class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload):
-
@property
def use_gelu(self) -> bool:
return True
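
In `workload.py` above, `init_model_fn` no longer takes `dropout_rate`/`aux_dropout_rate`; instead `model_fn` forwards a call-time `dropout_rate` (default `models.DROPOUT_RATE`) to the Flax model, and the loss path remains a thin wrapper around `optax.ctc_loss`. A small sketch of that loss call with toy shapes (the dimensions below are placeholders, not workload values):

```python
import jax.numpy as jnp
import optax

B, T, N, V = 2, 12, 4, 8                  # batch, frames, label length, vocab
logits = jnp.zeros((B, T, V))
logit_paddings = jnp.zeros((B, T))        # 1.0 marks padded frames
labels = jnp.ones((B, N), dtype=jnp.int32)
label_paddings = jnp.zeros((B, N))        # 1.0 marks padded labels

# Per-example CTC loss, shape [B]; blank_id=0 matches the workload wrapper.
per_example_loss = optax.ctc_loss(
    logits=logits,
    logit_paddings=logit_paddings,
    labels=labels,
    label_paddings=label_paddings,
    blank_id=0,
)
```
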
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py
index db1e24521..647b8ff0c 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py
@@ -2,34 +2,36 @@
https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py.
"""
+import math
from dataclasses import dataclass
from functools import partial
-import math
from typing import Tuple
import torch
+import torch.nn.functional as F
from torch import nn
from torch.nn import init
-import torch.nn.functional as F
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \
- preprocessor
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \
- SpecAug
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch import (
+ preprocessor,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import (
+ SpecAug,
+)
+
+DROPOUT_RATE = 0.1
@dataclass
class ConformerConfig:
"""Global hyperparameters used to minimize obnoxious kwarg plumbing."""
+
vocab_size: int = 1024
encoder_dim: int = 512
num_attention_heads: int = 8
num_encoder_layers: int = 4
attention_dropout_rate: float = 0.0
- attention_residual_dropout_rate: float = 0.1
- conv_residual_dropout_rate: float = 0.0
feed_forward_dropout_rate: float = 0.0
- feed_forward_residual_dropout_rate: float = 0.1
convolution_kernel_size: int = 5
feed_forward_expansion_factor: int = 4
freq_mask_count: int = 2
@@ -39,7 +41,6 @@ class ConformerConfig:
time_mask_max_ratio: float = 0.05
time_masks_per_frame: float = 0.0
use_dynamic_time_mask_max_frames: bool = True
- input_dropout_rate: float = 0.1
batch_norm_momentum: float = 1 - 0.999
batch_norm_epsilon: float = 0.001
use_specaug: bool = True
@@ -60,7 +61,6 @@ def initialize(m):
class LayerNorm(nn.Module):
-
def __init__(self, dim, epsilon=1e-6):
super().__init__()
self.dim = dim
@@ -74,28 +74,25 @@ def forward(self, x):
class Subsample(nn.Module):
-
- def __init__(self,
- encoder_dim: int = 0,
- input_dropout_rate: float = 0.0,
- num_bins: int = 80):
+ def __init__(self, encoder_dim: int = 0, num_bins: int = 80):
super().__init__()
self.encoder_dim = encoder_dim
- self.input_dropout_rate = input_dropout_rate
self.conv1 = Conv2dSubsampling(
- input_channels=1, output_channels=encoder_dim)
+ input_channels=1, output_channels=encoder_dim
+ )
self.conv2 = Conv2dSubsampling(
- input_channels=encoder_dim, output_channels=encoder_dim)
+ input_channels=encoder_dim, output_channels=encoder_dim
+ )
self.linear = nn.Linear(
- in_features=self.encoder_dim * num_bins // 4,
- out_features=self.encoder_dim,
- bias=True)
+ in_features=self.encoder_dim * num_bins // 4,
+ out_features=self.encoder_dim,
+ bias=True,
+ )
self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim)
- self.dropout = nn.Dropout(p=self.input_dropout_rate, inplace=True)
- def forward(self, inputs, input_paddings):
+ def forward(self, inputs, input_paddings, dropout_rate):
output_paddings = input_paddings
outputs = inputs[:, None, :, :]
@@ -103,24 +100,27 @@ def forward(self, inputs, input_paddings):
outputs, output_paddings = self.conv2(outputs, output_paddings)
batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape
- outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size,
- subsampled_lengths,
- subsampled_dims * channels)
+ outputs = outputs.permute(0, 2, 3, 1).reshape(
+ batch_size, subsampled_lengths, subsampled_dims * channels
+ )
outputs = self.linear(outputs)
outputs = outputs + self.pos_encode(seq_length=outputs.shape[1])
- outputs = self.dropout(outputs)
+ outputs = F.dropout(
+ outputs, dropout_rate, training=self.training, inplace=True
+ )
return outputs, output_paddings
class Conv2dSubsampling(nn.Module):
-
- def __init__(self,
- input_channels: int,
- output_channels: int,
- filter_stride: Tuple[int] = (2, 2),
- padding: str = 'SAME'):
+ def __init__(
+ self,
+ input_channels: int,
+ output_channels: int,
+ filter_stride: Tuple[int] = (2, 2),
+ padding: str = 'SAME',
+ ):
super().__init__()
self.input_channels = input_channels
@@ -131,7 +131,8 @@ def __init__(self,
self.filter_shape = (output_channels, input_channels, 3, 3)
self.kernel = nn.Parameter(
- torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape)))
+ torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape))
+ )
self.bias = nn.Parameter(torch.zeros(output_channels))
self.register_buffer('paddings_kernel', torch.ones([1, 1, 1]))
@@ -161,12 +162,13 @@ def forward(self, inputs, paddings):
else:
in_ = inputs
outputs = F.conv2d(
- input=in_,
- weight=self.kernel,
- bias=self.bias,
- stride=self.filter_stride,
- dilation=(1, 1),
- groups=groups)
+ input=in_,
+ weight=self.kernel,
+ bias=self.bias,
+ stride=self.filter_stride,
+ dilation=(1, 1),
+ groups=groups,
+ )
outputs = F.relu(outputs)
@@ -174,42 +176,37 @@ def forward(self, inputs, paddings):
stride = self.filter_stride[0]
pad_len = (input_length + stride - 1) // stride * stride - input_length
padded_paddings = F.pad(
- paddings[:, None, :], (0, pad_len), mode='constant', value=0)
+ paddings[:, None, :], (0, pad_len), mode='constant', value=0
+ )
out_padding = F.conv1d(
- input=padded_paddings,
- weight=self.paddings_kernel,
- stride=self.filter_stride[:1])
+ input=padded_paddings,
+ weight=self.paddings_kernel,
+ stride=self.filter_stride[:1],
+ )
out_padding = out_padding.squeeze(dim=1)
outputs = outputs * (1 - out_padding[:, None, :, None])
return outputs, out_padding
class FeedForwardModule(nn.Module):
-
def __init__(self, config: ConformerConfig):
super().__init__()
self.config = config
self.ln = LayerNorm(dim=config.encoder_dim)
self.linear1 = nn.Linear(
- in_features=config.encoder_dim,
- out_features=config.encoder_dim * config.feed_forward_expansion_factor,
- bias=True)
+ in_features=config.encoder_dim,
+ out_features=config.encoder_dim * config.feed_forward_expansion_factor,
+ bias=True,
+ )
self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate, inplace=True)
self.linear2 = nn.Linear(
- in_features=config.encoder_dim * config.feed_forward_expansion_factor,
- out_features=config.encoder_dim,
- bias=True)
-
- if config.feed_forward_residual_dropout_rate is None:
- feed_forward_residual_dropout_rate = 0.1
- else:
- feed_forward_residual_dropout_rate = (
- config.feed_forward_residual_dropout_rate)
- self.dropout2 = nn.Dropout(
- p=feed_forward_residual_dropout_rate, inplace=True)
+ in_features=config.encoder_dim * config.feed_forward_expansion_factor,
+ out_features=config.encoder_dim,
+ bias=True,
+ )
- def forward(self, inputs, padding_mask):
+ def forward(self, inputs, padding_mask, dropout_rate):
inputs = self.ln(inputs)
inputs = self.linear1(inputs)
if self.config.activation_function_name == 'swish':
@@ -218,51 +215,58 @@ def forward(self, inputs, padding_mask):
# Use tanh approximation of GELU which is default for jax
activation_fn = partial(F.gelu, approximate='tanh')
else:
- raise ValueError('Only "swish" and "gelu" are supported '
- 'config.activation_function_name values, recieved '
- f'{self.config.activation_function_name}')
+ raise ValueError(
+ 'Only "swish" and "gelu" are supported '
+ 'config.activation_function_name values, received '
+ f'{self.config.activation_function_name}'
+ )
inputs = activation_fn(inputs)
inputs = self.dropout1(inputs)
inputs = inputs * padding_mask
inputs = self.linear2(inputs)
inputs = inputs * padding_mask
- inputs = self.dropout2(inputs)
+ inputs = F.dropout(
+ inputs, dropout_rate, training=self.training, inplace=True
+ )
return inputs
class AddPositionalEmbedding(nn.Module):
-
- def __init__(self,
- min_timescale: int = 1,
- max_timescale: int = 10_000,
- embedding_dim: int = 512):
+ def __init__(
+ self,
+ min_timescale: int = 1,
+ max_timescale: int = 10_000,
+ embedding_dim: int = 512,
+ ):
super().__init__()
self.min_timescale = min_timescale
self.max_timescale = max_timescale
self.embedding_dim = embedding_dim
num_timescales = self.embedding_dim // 2
log_timescale_increment = math.log(
- float(self.max_timescale) / float(self.min_timescale)) / (
- num_timescales - 1)
- inv_timescales = self.min_timescale * \
- torch.exp(torch.arange(num_timescales, dtype=torch.float32)
- * -log_timescale_increment)
+ float(self.max_timescale) / float(self.min_timescale)
+ ) / (num_timescales - 1)
+ inv_timescales = self.min_timescale * torch.exp(
+ torch.arange(num_timescales, dtype=torch.float32)
+ * -log_timescale_increment
+ )
self.register_buffer('inv_timescales', inv_timescales[None, None, :])
def forward(self, seq_length):
position = torch.arange(
- end=seq_length, dtype=torch.float32, device=self.inv_timescales.device)
+ end=seq_length, dtype=torch.float32, device=self.inv_timescales.device
+ )
scaled_time = position[None, :, None] * self.inv_timescales
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
if self.embedding_dim % 2:
signal = torch.cat(
- [signal, torch.zeros(signal.shape[0], signal.shape[1], 1)], dim=2)
+ [signal, torch.zeros(signal.shape[0], signal.shape[1], 1)], dim=2
+ )
return signal
class QueryScaler(nn.Module):
-
def __init__(self, dim):
super().__init__()
self.dim = dim
@@ -275,12 +279,11 @@ def forward(self, inputs):
class MHSAwithQS(nn.Module):
-
def __init__(self, config: ConformerConfig):
super().__init__()
self.embed_dim = config.encoder_dim
self.num_heads = config.num_attention_heads
- self.dropout = config.attention_dropout_rate
+ self.attention_dropout_rate = config.attention_dropout_rate
self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim)
self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim)
self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads)
@@ -292,20 +295,23 @@ def forward(self, inputs, key_padding_mask=None):
q = self.qs(q.view(batch_size, seq_len, self.num_heads, -1)).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
- out = F.scaled_dot_product_attention(
+ out = (
+ F.scaled_dot_product_attention(
query=q,
key=k,
value=v,
attn_mask=~key_padding_mask[:, None, None],
- dropout_p=self.dropout,
- ).transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
+ dropout_p=self.attention_dropout_rate,
+ )
+ .transpose(1, 2)
+ .reshape(batch_size, seq_len, embed_dim)
+ )
out = out * self.attention_temperature
out = self.out_proj(out)
return out
class MultiHeadedSelfAttention(nn.Module):
-
def __init__(self, config: ConformerConfig):
super().__init__()
@@ -313,24 +319,20 @@ def __init__(self, config: ConformerConfig):
self.ln = LayerNorm(dim=config.encoder_dim)
self.self_attention = MHSAwithQS(config)
- if config.attention_residual_dropout_rate is None:
- attention_residual_dropout_rate = 0.1
- else:
- attention_residual_dropout_rate = config.attention_residual_dropout_rate
- self.dropout = nn.Dropout(p=attention_residual_dropout_rate, inplace=True)
- def forward(self, outputs, paddings):
+ def forward(self, outputs, paddings, dropout_rate):
outputs = self.ln(outputs)
outputs = self.self_attention(
- outputs,
- key_padding_mask=paddings == 1,
+ outputs,
+ key_padding_mask=paddings == 1,
+ )
+ outputs = F.dropout(
+ outputs, dropout_rate, training=self.training, inplace=True
)
- outputs = self.dropout(outputs)
return outputs
class BatchNorm(nn.Module):
-
def __init__(self, config: ConformerConfig):
super().__init__()
running_mean = torch.zeros(config.encoder_dim)
@@ -345,8 +347,8 @@ def __init__(self, config: ConformerConfig):
self.epsilon = config.batch_norm_epsilon
def forward(self, inputs, input_paddings):
- #inputs: NHD
- #padding: NH
+ # inputs: NHD
+ # padding: NH
"""
Alternatively:
inputs[input_paddings==0] = F.batch_norm(
@@ -370,9 +372,11 @@ def forward(self, inputs, input_paddings):
var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count
self.running_mean = (1 - self.momentum) * self.running_mean + (
- self.momentum) * mean.detach()
+ self.momentum
+ ) * mean.detach()
self.running_var = (1 - self.momentum) * self.running_var + (
- self.momentum) * var.detach()
+ self.momentum
+ ) * var.detach()
else:
mean = self.running_mean
@@ -384,34 +388,31 @@ def forward(self, inputs, input_paddings):
class ConvolutionBlock(nn.Module):
-
def __init__(self, config):
super().__init__()
self.config = config
self.ln = LayerNorm(dim=config.encoder_dim)
self.lin1 = nn.Linear(
- in_features=config.encoder_dim, out_features=config.encoder_dim)
+ in_features=config.encoder_dim, out_features=config.encoder_dim
+ )
self.lin2 = nn.Linear(
- in_features=config.encoder_dim, out_features=config.encoder_dim)
+ in_features=config.encoder_dim, out_features=config.encoder_dim
+ )
self.conv1 = nn.Conv1d(
- in_channels=config.encoder_dim,
- out_channels=config.encoder_dim,
- kernel_size=(config.convolution_kernel_size,),
- stride=(1,),
- padding='same',
- bias=False,
- groups=config.encoder_dim)
+ in_channels=config.encoder_dim,
+ out_channels=config.encoder_dim,
+ kernel_size=(config.convolution_kernel_size,),
+ stride=(1,),
+ padding='same',
+ bias=False,
+ groups=config.encoder_dim,
+ )
self.bn = BatchNorm(config)
self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim)
- if config.conv_residual_dropout_rate is None:
- conv_residual_dropout_rate = 0.0
- else:
- conv_residual_dropout_rate = config.conv_residual_dropout_rate
- self.dropout = nn.Dropout(p=conv_residual_dropout_rate, inplace=True)
- def forward(self, inputs, input_paddings):
+ def forward(self, inputs, input_paddings, dropout_rate):
inputs = self.ln(inputs)
inputs = F.glu(torch.cat([self.lin1(inputs), self.lin2(inputs)], dim=2))
@@ -427,18 +428,21 @@ def forward(self, inputs, input_paddings):
elif self.config.activation_function_name == 'gelu':
activation_fn = F.gelu
else:
- raise ValueError('Only "swish" and "gelu" are supported '
- 'config.activation_function_name values, recieved '
- f'{self.config.activation_function_name}')
+ raise ValueError(
+ 'Only "swish" and "gelu" are supported '
+        'config.activation_function_name values, received '
+ f'{self.config.activation_function_name}'
+ )
inputs = activation_fn(inputs)
inputs = self.lin3(inputs)
- inputs = self.dropout(inputs)
+ inputs = F.dropout(
+ inputs, dropout_rate, training=self.training, inplace=True
+ )
return inputs
class ConformerBlock(nn.Module):
-
def __init__(self, config: ConformerConfig):
super().__init__()
@@ -450,59 +454,57 @@ def __init__(self, config: ConformerConfig):
if config.use_post_layer_norm:
self.ln = LayerNorm(dim=config.encoder_dim)
- def forward(self, inputs, input_paddings):
+ def forward(self, inputs, input_paddings, dropout_rate):
padding_mask = 1 - input_paddings[:, :, None]
- inputs = inputs + 0.5 * self.ff1(inputs, padding_mask)
- inputs = inputs + self.mhsa(inputs, input_paddings)
- inputs = inputs + self.conv(inputs, input_paddings)
- inputs = inputs + 0.5 * self.ff2(inputs, padding_mask)
+ inputs = inputs + 0.5 * self.ff1(inputs, padding_mask, dropout_rate)
+ inputs = inputs + self.mhsa(inputs, input_paddings, dropout_rate)
+ inputs = inputs + self.conv(inputs, input_paddings, dropout_rate)
+ inputs = inputs + 0.5 * self.ff2(inputs, padding_mask, dropout_rate)
if self.ln:
inputs = self.ln(inputs)
return inputs
class ConformerEncoderDecoder(nn.Module):
-
def __init__(self, config: ConformerConfig):
super().__init__()
self.config = config
preprocessing_config = preprocessor.PreprocessorConfig()
self.preprocessor = preprocessor.MelFilterbankFrontend(
- preprocessing_config,
- per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
- per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)
+ preprocessing_config,
+ per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
+ per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR,
+ )
self.specaug = SpecAug(
- freq_mask_count=config.freq_mask_count,
- freq_mask_max_bins=config.freq_mask_max_bins,
- time_mask_count=config.time_mask_count,
- time_mask_max_frames=config.time_mask_max_frames,
- time_mask_max_ratio=config.time_mask_max_ratio,
- time_masks_per_frame=config.time_masks_per_frame,
- use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames
+ freq_mask_count=config.freq_mask_count,
+ freq_mask_max_bins=config.freq_mask_max_bins,
+ time_mask_count=config.time_mask_count,
+ time_mask_max_frames=config.time_mask_max_frames,
+ time_mask_max_ratio=config.time_mask_max_ratio,
+ time_masks_per_frame=config.time_masks_per_frame,
+ use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames,
)
- if config.input_dropout_rate is None:
- input_dropout_rate = 0.1
- else:
- input_dropout_rate = config.input_dropout_rate
self.subsample = Subsample(
- encoder_dim=config.encoder_dim,
- input_dropout_rate=input_dropout_rate,
- num_bins=preprocessing_config.num_bins)
+ encoder_dim=config.encoder_dim, num_bins=preprocessing_config.num_bins
+ )
self.conformers = nn.ModuleList(
- [ConformerBlock(config) for _ in range(config.num_encoder_layers)])
+ [ConformerBlock(config) for _ in range(config.num_encoder_layers)]
+ )
self.ln = LayerNorm(config.encoder_dim)
self.lin = nn.Linear(config.encoder_dim, config.vocab_size)
- def forward(self, inputs, input_paddings):
+ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE):
outputs = inputs
output_paddings = input_paddings
outputs, output_paddings = self.preprocessor(outputs, output_paddings)
if self.training and self.config.use_specaug:
outputs, output_paddings = self.specaug(outputs, output_paddings)
- outputs, output_paddings = self.subsample(outputs, output_paddings)
+ outputs, output_paddings = self.subsample(
+ outputs, output_paddings, dropout_rate
+ )
for conformer in self.conformers:
- outputs = conformer(outputs, output_paddings)
+ outputs = conformer(outputs, output_paddings, dropout_rate)
outputs = self.ln(outputs)
outputs = self.lin(outputs)
return outputs, output_paddings
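
The conformer changes above replace `nn.Dropout` modules configured at construction time with functional `F.dropout` calls that take a `dropout_rate` argument at forward time, defaulting to a module-level `DROPOUT_RATE`. A minimal sketch of that pattern; the `ToyBlock` module and the `0.1` default are illustrative only, not part of the repository:

```python
# Minimal sketch of the dropout refactor: the rate is no longer baked into an
# nn.Dropout at __init__ time, it is passed to forward() and applied with the
# functional API. ToyBlock and the 0.1 default are illustrative assumptions.
import torch
import torch.nn.functional as F
from torch import nn

DROPOUT_RATE = 0.1  # assumed default, mirroring the module-level constant used above


class ToyBlock(nn.Module):
  def __init__(self, dim: int = 8):
    super().__init__()
    self.linear = nn.Linear(dim, dim)

  def forward(self, x: torch.Tensor, dropout_rate: float = DROPOUT_RATE) -> torch.Tensor:
    x = self.linear(x)
    # training=self.training makes this a no-op in eval mode, so the call
    # site needs no extra guard.
    return F.dropout(x, p=dropout_rate, training=self.training, inplace=True)


block = ToyBlock()
block.train()
out_train = block(torch.randn(2, 8), dropout_rate=0.2)  # stochastic
block.eval()
out_eval = block(torch.randn(2, 8), dropout_rate=0.2)   # deterministic
```
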
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py
index 558a0f796..58dd837dc 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py
@@ -2,188 +2,189 @@
https://github.com/google/init2winit/blob/master/init2winit/model_lib/librispeech_preprocessor.py.
"""
-from dataclasses import dataclass
import math
+from dataclasses import dataclass
from typing import Any, Optional, Union
import numpy as np
import torch
-from torch import nn
import torch.nn.functional as F
+from torch import nn
# mel spectrum constants.
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
_MEL_HIGH_FREQUENCY_Q = 1127.0
LIBRISPEECH_MEAN_VECTOR = [
- -7.6047816276550293,
- -7.1206226348876953,
- -6.8864245414733887,
- -6.8705768585205078,
- -6.9667720794677734,
- -7.1084094047546387,
- -6.9528026580810547,
- -6.783994197845459,
- -6.6195521354675293,
- -6.4876265525817871,
- -6.4120659828186035,
- -6.394047737121582,
- -6.4244871139526367,
- -6.3993711471557617,
- -6.5158271789550781,
- -6.7137999534606934,
- -6.8476877212524414,
- -6.9885001182556152,
- -6.9221386909484863,
- -7.146148681640625,
- -7.2040400505065918,
- -7.0537552833557129,
- -7.3140382766723633,
- -7.1223249435424805,
- -7.30251407623291,
- -7.1212143898010254,
- -7.2425732612609863,
- -7.1730537414550781,
- -7.0979413986206055,
- -7.088747501373291,
- -6.9849910736083984,
- -6.8787732124328613,
- -6.7602753639221191,
- -6.6300945281982422,
- -6.5145769119262695,
- -6.4245057106018066,
- -6.356513500213623,
- -6.31787633895874,
- -6.2660770416259766,
- -6.2468328475952148,
- -6.2821526527404785,
- -6.1908388137817383,
- -6.2484354972839355,
- -6.1472640037536621,
- -6.0924725532531738,
- -6.0171003341674805,
- -5.9250402450561523,
- -5.8535833358764648,
- -5.8209109306335449,
- -5.8118929862976074,
- -5.80783748626709,
- -5.7714629173278809,
- -5.7453732490539551,
- -5.7705655097961426,
- -5.7765641212463379,
- -5.7831673622131348,
- -5.7954087257385254,
- -5.7994823455810547,
- -5.8023476600646973,
- -5.8047118186950684,
- -5.8168182373046875,
- -5.8844799995422363,
- -5.9727106094360352,
- -6.0444660186767578,
- -6.1284866333007812,
- -6.2257585525512695,
- -6.3157496452331543,
- -6.39061164855957,
- -6.4928598403930664,
- -6.5498456954956055,
- -6.6054320335388184,
- -6.6508378982543945,
- -6.66917610168457,
- -6.6726889610290527,
- -6.684234619140625,
- -6.6974577903747559,
- -6.75471830368042,
- -6.7949142456054688,
- -6.8634209632873535,
- -6.94186544418335
+ -7.6047816276550293,
+ -7.1206226348876953,
+ -6.8864245414733887,
+ -6.8705768585205078,
+ -6.9667720794677734,
+ -7.1084094047546387,
+ -6.9528026580810547,
+ -6.783994197845459,
+ -6.6195521354675293,
+ -6.4876265525817871,
+ -6.4120659828186035,
+ -6.394047737121582,
+ -6.4244871139526367,
+ -6.3993711471557617,
+ -6.5158271789550781,
+ -6.7137999534606934,
+ -6.8476877212524414,
+ -6.9885001182556152,
+ -6.9221386909484863,
+ -7.146148681640625,
+ -7.2040400505065918,
+ -7.0537552833557129,
+ -7.3140382766723633,
+ -7.1223249435424805,
+ -7.30251407623291,
+ -7.1212143898010254,
+ -7.2425732612609863,
+ -7.1730537414550781,
+ -7.0979413986206055,
+ -7.088747501373291,
+ -6.9849910736083984,
+ -6.8787732124328613,
+ -6.7602753639221191,
+ -6.6300945281982422,
+ -6.5145769119262695,
+ -6.4245057106018066,
+ -6.356513500213623,
+ -6.31787633895874,
+ -6.2660770416259766,
+ -6.2468328475952148,
+ -6.2821526527404785,
+ -6.1908388137817383,
+ -6.2484354972839355,
+ -6.1472640037536621,
+ -6.0924725532531738,
+ -6.0171003341674805,
+ -5.9250402450561523,
+ -5.8535833358764648,
+ -5.8209109306335449,
+ -5.8118929862976074,
+ -5.80783748626709,
+ -5.7714629173278809,
+ -5.7453732490539551,
+ -5.7705655097961426,
+ -5.7765641212463379,
+ -5.7831673622131348,
+ -5.7954087257385254,
+ -5.7994823455810547,
+ -5.8023476600646973,
+ -5.8047118186950684,
+ -5.8168182373046875,
+ -5.8844799995422363,
+ -5.9727106094360352,
+ -6.0444660186767578,
+ -6.1284866333007812,
+ -6.2257585525512695,
+ -6.3157496452331543,
+ -6.39061164855957,
+ -6.4928598403930664,
+ -6.5498456954956055,
+ -6.6054320335388184,
+ -6.6508378982543945,
+ -6.66917610168457,
+ -6.6726889610290527,
+ -6.684234619140625,
+ -6.6974577903747559,
+ -6.75471830368042,
+ -6.7949142456054688,
+ -6.8634209632873535,
+ -6.94186544418335,
]
LIBRISPEECH_STD_VECTOR = [
- 3.4353282451629639,
- 3.5962932109832764,
- 3.7012472152709961,
- 3.7369205951690674,
- 3.7535104751586914,
- 3.693629264831543,
- 3.6922497749328613,
- 3.7641522884368896,
- 3.8419716358184814,
- 3.8999848365783691,
- 3.9294240474700928,
- 3.9317409992218018,
- 3.9139585494995117,
- 3.9031598567962646,
- 3.8691999912261963,
- 3.8155081272125244,
- 3.7644970417022705,
- 3.7099106311798096,
- 3.6965086460113525,
- 3.6003766059875488,
- 3.5493226051330566,
- 3.5465121269226074,
- 3.45003604888916,
- 3.4712812900543213,
- 3.4084610939025879,
- 3.4408135414123535,
- 3.4104881286621094,
- 3.4217638969421387,
- 3.4312851428985596,
- 3.4199209213256836,
- 3.4305806159973145,
- 3.4382665157318115,
- 3.4580366611480713,
- 3.4817991256713867,
- 3.4958710670471191,
- 3.5036792755126953,
- 3.5047574043273926,
- 3.4988734722137451,
- 3.493056058883667,
- 3.4822943210601807,
- 3.459430456161499,
- 3.4612770080566406,
- 3.4559063911437988,
- 3.4755423069000244,
- 3.4971549510955811,
- 3.5326557159423828,
- 3.5705199241638184,
- 3.5920312404632568,
- 3.596907377243042,
- 3.5913500785827637,
- 3.5865931510925293,
- 3.5826809406280518,
- 3.5837743282318115,
- 3.5895791053771973,
- 3.5819313526153564,
- 3.5837869644165039,
- 3.5861184597015381,
- 3.5889589786529541,
- 3.592214822769165,
- 3.5939455032348633,
- 3.5856630802154541,
- 3.5884113311767578,
- 3.5921022891998291,
- 3.5870490074157715,
- 3.5806570053100586,
- 3.5731067657470703,
- 3.5617532730102539,
- 3.54980731010437,
- 3.5527374744415283,
- 3.5475366115570068,
- 3.5387849807739258,
- 3.5256178379058838,
- 3.5031836032867432,
- 3.4922726154327393,
- 3.4879646301269531,
- 3.4725594520568848,
- 3.4558389186859131,
- 3.4351828098297119,
- 3.4284293651580811,
- 3.4299170970916748
+ 3.4353282451629639,
+ 3.5962932109832764,
+ 3.7012472152709961,
+ 3.7369205951690674,
+ 3.7535104751586914,
+ 3.693629264831543,
+ 3.6922497749328613,
+ 3.7641522884368896,
+ 3.8419716358184814,
+ 3.8999848365783691,
+ 3.9294240474700928,
+ 3.9317409992218018,
+ 3.9139585494995117,
+ 3.9031598567962646,
+ 3.8691999912261963,
+ 3.8155081272125244,
+ 3.7644970417022705,
+ 3.7099106311798096,
+ 3.6965086460113525,
+ 3.6003766059875488,
+ 3.5493226051330566,
+ 3.5465121269226074,
+ 3.45003604888916,
+ 3.4712812900543213,
+ 3.4084610939025879,
+ 3.4408135414123535,
+ 3.4104881286621094,
+ 3.4217638969421387,
+ 3.4312851428985596,
+ 3.4199209213256836,
+ 3.4305806159973145,
+ 3.4382665157318115,
+ 3.4580366611480713,
+ 3.4817991256713867,
+ 3.4958710670471191,
+ 3.5036792755126953,
+ 3.5047574043273926,
+ 3.4988734722137451,
+ 3.493056058883667,
+ 3.4822943210601807,
+ 3.459430456161499,
+ 3.4612770080566406,
+ 3.4559063911437988,
+ 3.4755423069000244,
+ 3.4971549510955811,
+ 3.5326557159423828,
+ 3.5705199241638184,
+ 3.5920312404632568,
+ 3.596907377243042,
+ 3.5913500785827637,
+ 3.5865931510925293,
+ 3.5826809406280518,
+ 3.5837743282318115,
+ 3.5895791053771973,
+ 3.5819313526153564,
+ 3.5837869644165039,
+ 3.5861184597015381,
+ 3.5889589786529541,
+ 3.592214822769165,
+ 3.5939455032348633,
+ 3.5856630802154541,
+ 3.5884113311767578,
+ 3.5921022891998291,
+ 3.5870490074157715,
+ 3.5806570053100586,
+ 3.5731067657470703,
+ 3.5617532730102539,
+ 3.54980731010437,
+ 3.5527374744415283,
+ 3.5475366115570068,
+ 3.5387849807739258,
+ 3.5256178379058838,
+ 3.5031836032867432,
+ 3.4922726154327393,
+ 3.4879646301269531,
+ 3.4725594520568848,
+ 3.4558389186859131,
+ 3.4351828098297119,
+ 3.4284293651580811,
+ 3.4299170970916748,
]
@dataclass
class PreprocessorConfig:
"""Global hyperparameters used to minimize obnoxious kwarg plumbing."""
+
sample_rate = 16000
frame_size_ms = 25
frame_step_ms = 10
@@ -203,10 +204,12 @@ class PreprocessorConfig:
def _hertz_to_mel(frequencies_hertz):
"""Convert hertz to mel."""
- log_fn = math.log if type(frequencies_hertz) in [type(0.0), type(0)
- ] else torch.log
- return _MEL_HIGH_FREQUENCY_Q * log_fn(1.0 + (frequencies_hertz /
- _MEL_BREAK_FREQUENCY_HERTZ))
+ log_fn = (
+ math.log if type(frequencies_hertz) in [type(0.0), type(0)] else torch.log
+ )
+ return _MEL_HIGH_FREQUENCY_Q * log_fn(
+ 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)
+ )
def _pad_end_length(num_timesteps, frame_step, frame_size):
@@ -218,28 +221,30 @@ def _pad_end_length(num_timesteps, frame_step, frame_size):
return padded_length - num_timesteps
-def frame(x,
- frame_length: int,
- frame_step: int,
- pad_end: bool = False,
- pad_value: Union[int, float] = 0.0):
+def frame(
+ x,
+ frame_length: int,
+ frame_step: int,
+ pad_end: bool = False,
+ pad_value: Union[int, float] = 0.0,
+):
"""Slides a window and extract values.
- This function extracts `x[:, n:n+frame_length, :]` with sliding `n` with
- stride of `frame_step`, and returns an array `y` with the shape
- `(batch_size, num_frames, frame_length, num_channels)`. Unlike the
- counterpart in Tensorflow (`tf.signal.frame`), this function currently
- does not take `axis` argument, and the input tensor `x` is expected to
- have a shape of `(batch_size, timesteps, channels)`.
- Args:
- x: An input array with `(batch_size, timesteps, channels)`-shape.
- frame_length: The frame length.
- frame_step: The frame hop size.
- pad_end: If True, the end of signal is padded so the window can continue
- sliding while the starting point of the window is in the valid range.
- pad_value: A scalar used as a padding value when `pad_end` is True.
- Returns:
- A tensor with shape `(*, num_frames, frame_length, num_channels)`.
- """
+ This function extracts `x[:, n:n+frame_length, :]` with sliding `n` with
+ stride of `frame_step`, and returns an array `y` with the shape
+ `(batch_size, num_frames, frame_length, num_channels)`. Unlike the
+ counterpart in Tensorflow (`tf.signal.frame`), this function currently
+ does not take `axis` argument, and the input tensor `x` is expected to
+ have a shape of `(batch_size, timesteps, channels)`.
+ Args:
+ x: An input array with `(batch_size, timesteps, channels)`-shape.
+ frame_length: The frame length.
+ frame_step: The frame hop size.
+ pad_end: If True, the end of signal is padded so the window can continue
+ sliding while the starting point of the window is in the valid range.
+ pad_value: A scalar used as a padding value when `pad_end` is True.
+ Returns:
+ A tensor with shape `(*, num_frames, frame_length, num_channels)`.
+ """
num_timesteps = x.shape[1]
if pad_end:
@@ -250,60 +255,67 @@ def frame(x,
return x.permute(0, 1, 3, 2)
-def linear_to_mel_weight_matrix(num_mel_bins: int = 20,
- num_spectrogram_bins: int = 129,
- sample_rate: Union[int, float] = 8000,
- lower_edge_hertz: Union[int, float] = 125.0,
- upper_edge_hertz: Union[int, float] = 3800.0,
- dtype: Any = torch.float32,
- device='cpu'):
+def linear_to_mel_weight_matrix(
+ num_mel_bins: int = 20,
+ num_spectrogram_bins: int = 129,
+ sample_rate: Union[int, float] = 8000,
+ lower_edge_hertz: Union[int, float] = 125.0,
+ upper_edge_hertz: Union[int, float] = 3800.0,
+ dtype: Any = torch.float32,
+ device='cpu',
+):
r"""Pytorch-port of `tf.signal.linear_to_mel_weight_matrix`.
- Args:
- num_mel_bins: Python int. How many bands in the resulting mel spectrum.
- num_spectrogram_bins: An integer `Tensor`. How many bins there are in
- the source spectrogram data, which is understood to be `fft_size // 2 + 1`,
- i.e. the spectrogram only contains the nonredundant FFT bins.
- sample_rate: An integer or float `Tensor`. Samples per second of the
- input signal used to create the spectrogram. Used to figure out the
- frequencies corresponding to each spectrogram bin, which dictates how they
- are mapped into the mel scale.
- lower_edge_hertz: Python float. Lower bound on the frequencies to be
- included in the mel spectrum. This corresponds to the lower edge of the
- lowest triangular band.
- upper_edge_hertz: Python float. The desired top edge of the highest
- frequency band.
- dtype: The `DType` of the result matrix. Must be a floating point type.
- Returns:
- An array of shape `[num_spectrogram_bins, num_mel_bins]`.
- Raises:
- ValueError: If `num_mel_bins`/`num_spectrogram_bins`/`sample_rate` are not
- positive, `lower_edge_hertz` is negative, frequency edges are incorrectly
- ordered, `upper_edge_hertz` is larger than the Nyquist frequency.
- [mel]: https://en.wikipedia.org/wiki/Mel_scale
- """
+ Args:
+ num_mel_bins: Python int. How many bands in the resulting mel spectrum.
+ num_spectrogram_bins: An integer `Tensor`. How many bins there are in
+ the source spectrogram data, which is understood to be `fft_size // 2 + 1`,
+ i.e. the spectrogram only contains the nonredundant FFT bins.
+ sample_rate: An integer or float `Tensor`. Samples per second of the
+ input signal used to create the spectrogram. Used to figure out the
+ frequencies corresponding to each spectrogram bin, which dictates how they
+ are mapped into the mel scale.
+ lower_edge_hertz: Python float. Lower bound on the frequencies to be
+ included in the mel spectrum. This corresponds to the lower edge of the
+ lowest triangular band.
+ upper_edge_hertz: Python float. The desired top edge of the highest
+ frequency band.
+ dtype: The `DType` of the result matrix. Must be a floating point type.
+ Returns:
+ An array of shape `[num_spectrogram_bins, num_mel_bins]`.
+ Raises:
+ ValueError: If `num_mel_bins`/`num_spectrogram_bins`/`sample_rate` are not
+ positive, `lower_edge_hertz` is negative, frequency edges are incorrectly
+ ordered, `upper_edge_hertz` is larger than the Nyquist frequency.
+ [mel]: https://en.wikipedia.org/wiki/Mel_scale
+ """
# Input validator from tensorflow/python/ops/signal/mel_ops.py#L71
if num_mel_bins <= 0:
raise ValueError('num_mel_bins must be positive. Got: %s' % num_mel_bins)
if lower_edge_hertz < 0.0:
- raise ValueError('lower_edge_hertz must be non-negative. Got: %s' %
- lower_edge_hertz)
+ raise ValueError(
+ 'lower_edge_hertz must be non-negative. Got: %s' % lower_edge_hertz
+ )
if lower_edge_hertz >= upper_edge_hertz:
- raise ValueError('lower_edge_hertz %.1f >= upper_edge_hertz %.1f' %
- (lower_edge_hertz, upper_edge_hertz))
+ raise ValueError(
+ 'lower_edge_hertz %.1f >= upper_edge_hertz %.1f'
+ % (lower_edge_hertz, upper_edge_hertz)
+ )
if sample_rate <= 0.0:
raise ValueError('sample_rate must be positive. Got: %s' % sample_rate)
if upper_edge_hertz > sample_rate / 2:
- raise ValueError('upper_edge_hertz must not be larger than the Nyquist '
- 'frequency (sample_rate / 2). Got %s for sample_rate: %s' %
- (upper_edge_hertz, sample_rate))
+ raise ValueError(
+ 'upper_edge_hertz must not be larger than the Nyquist '
+ 'frequency (sample_rate / 2). Got %s for sample_rate: %s'
+ % (upper_edge_hertz, sample_rate)
+ )
# HTK excludes the spectrogram DC bin.
bands_to_zero = 1
nyquist_hertz = sample_rate / 2.0
linear_frequencies = torch.linspace(
- 0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype,
- device=device)[bands_to_zero:]
+ 0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype, device=device
+ )[bands_to_zero:]
spectrogram_bins_mel = _hertz_to_mel(linear_frequencies)[:, None]
# Compute num_mel_bins triples of (lower_edge, center, upper_edge). The
@@ -311,11 +323,12 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20,
# Accordingly, we divide [lower_edge_hertz, upper_edge_hertz] into
# num_mel_bins + 2 pieces.
edges = torch.linspace(
- _hertz_to_mel(lower_edge_hertz),
- _hertz_to_mel(upper_edge_hertz),
- num_mel_bins + 2,
- dtype=dtype,
- device=device)
+ _hertz_to_mel(lower_edge_hertz),
+ _hertz_to_mel(upper_edge_hertz),
+ num_mel_bins + 2,
+ dtype=dtype,
+ device=device,
+ )
# Split the triples up and reshape them into [1, num_mel_bins] tensors.
lower_edge_mel = edges[:-2][None, :]
@@ -325,13 +338,16 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20,
# Calculate lower and upper slopes for every spectrogram bin.
# Line segments are linear in the mel domain, not Hertz.
lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / (
- center_mel - lower_edge_mel)
+ center_mel - lower_edge_mel
+ )
upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / (
- upper_edge_mel - center_mel)
+ upper_edge_mel - center_mel
+ )
# Intersect the line segments with each other and zero.
mel_weights_matrix = torch.minimum(lower_slopes, upper_slopes).clamp(
- min=0.0, max=None)
+ min=0.0, max=None
+ )
# Re-add the zeroed lower bins we sliced out above.
return F.pad(mel_weights_matrix, (0, 0, bands_to_zero, 0))
@@ -339,43 +355,50 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20,
def _hanning_greco(win_support, frame_size, dtype, device='cpu'):
"""Add a greco-style hanning window to the graph.
- Note that the Hanning window in Wikipedia is not the same as the Hanning
- window in Greco. The Greco3 Hanning window at 0 is NOT 0, as the wikipedia
- page would indicate. Talkin's explanation was that it was like wasting two
- samples to have the values at the edge of the window to be 0.0 exactly.
- Args:
- win_support: Number of samples for non-zero support in the window
- frame_size: Total size of the window (frame_size >= win_support)
- dtype: TF data type
- Returns:
- Tensor of size frame_size with the window to apply.
- """
+ Note that the Hanning window in Wikipedia is not the same as the Hanning
+ window in Greco. The Greco3 Hanning window at 0 is NOT 0, as the wikipedia
+ page would indicate. Talkin's explanation was that it was like wasting two
+ samples to have the values at the edge of the window to be 0.0 exactly.
+ Args:
+ win_support: Number of samples for non-zero support in the window
+ frame_size: Total size of the window (frame_size >= win_support)
+ dtype: TF data type
+ Returns:
+ Tensor of size frame_size with the window to apply.
+ """
if frame_size < win_support:
raise ValueError(
- 'Provided frame_size = {} is lower than win_support = {}'.format(
- frame_size, win_support))
+ 'Provided frame_size = {} is lower than win_support = {}'.format(
+ frame_size, win_support
+ )
+ )
arg = torch.pi * 2.0 / (win_support)
- hann = 0.5 - (0.5 * torch.cos(
- arg * (torch.arange(win_support, dtype=dtype, device=device) + 0.5)))
+ hann = 0.5 - (
+ 0.5
+ * torch.cos(
+ arg * (torch.arange(win_support, dtype=dtype, device=device) + 0.5)
+ )
+ )
zero_size = frame_size - win_support
return F.pad(hann, (0, zero_size))
def _next_pow_of_two(x: Union[int, float]) -> int:
- return int(2**np.ceil(np.log2(x)))
+ return int(2 ** np.ceil(np.log2(x)))
class SpectrogramFrontend(nn.Module):
- """Layer to convert input audio signals from time domain to frequency domain.
- """
-
- def __init__(self,
- config: PreprocessorConfig = None,
- input_scale_factor: float = 1.0,
- output_log: bool = False,
- dtype=torch.float32,
- device='cpu'):
+ """Layer to convert input audio signals from time domain to frequency domain."""
+
+ def __init__(
+ self,
+ config: PreprocessorConfig = None,
+ input_scale_factor: float = 1.0,
+ output_log: bool = False,
+ dtype=torch.float32,
+ device='cpu',
+ ):
super().__init__()
self.config = config
@@ -384,8 +407,9 @@ def __init__(self,
p = self.config
self._frame_step = int(round(p.sample_rate * p.frame_step_ms / 1000.0))
- self._frame_size = int(round(
- p.sample_rate * p.frame_size_ms / 1000.0)) + 1 # +1 for the preemph
+ self._frame_size = (
+ int(round(p.sample_rate * p.frame_size_ms / 1000.0)) + 1
+ ) # +1 for the preemph
# TF-version has maximum of 512, but it's not always necessary
self.fft_size = _next_pow_of_two(self._frame_size)
@@ -399,23 +423,20 @@ def _hanning_window(frame_size, dtype):
if frame_size % 2 == 0:
# simulate periodic=True in tf.signal.hann_window
return torch.hann_window(
- window_length=frame_size,
- periodic=True,
- dtype=dtype,
- device=device)
+ window_length=frame_size, periodic=True, dtype=dtype, device=device
+ )
else:
return torch.hann_window(
- window_length=frame_size,
- periodic=False,
- dtype=dtype,
- device=device)
+ window_length=frame_size, periodic=False, dtype=dtype, device=device
+ )
self._window_fn = _hanning_window
elif p.window_fn.upper() == 'HANNING_GRECO':
# Greco-compatible hanning window
def f(frame_size, dtype):
return _hanning_greco(
- self._frame_size - 1, frame_size, dtype, device=device)
+ self._frame_size - 1, frame_size, dtype, device=device
+ )
self._window_fn = f
else:
@@ -430,25 +451,31 @@ def f(frame_size, dtype):
def _apply_preemphasis(self, framed_signal):
p = self.config
if p.preemph_htk_flavor:
- return torch.cat([
- framed_signal[:, :, :1, :] * (1. - p.preemph),
- (framed_signal[:, :, 1:-1, :] -
- p.preemph * framed_signal[:, :, :-2, :])
- ],
- dim=2)
+ return torch.cat(
+ [
+ framed_signal[:, :, :1, :] * (1.0 - p.preemph),
+ (
+ framed_signal[:, :, 1:-1, :]
+ - p.preemph * framed_signal[:, :, :-2, :]
+ ),
+ ],
+ dim=2,
+ )
else:
- return (framed_signal[:, :, 1:, :] -
- p.preemph * framed_signal[:, :, :-1, :])
+ return (
+ framed_signal[:, :, 1:, :] - p.preemph * framed_signal[:, :, :-1, :]
+ )
def fprop_paddings(self, input_paddings):
p = self.config
if p.pad_end:
- num_extends = _pad_end_length(input_paddings.shape[1],
- self._frame_step,
- self._frame_size)
+ num_extends = _pad_end_length(
+ input_paddings.shape[1], self._frame_step, self._frame_size
+ )
input_paddings = F.pad(input_paddings, (0, num_extends), value=1.0)
x = input_paddings.unfold(
- dimension=1, size=self._frame_size, step=self._frame_step)
+ dimension=1, size=self._frame_size, step=self._frame_step
+ )
return x.min(dim=2)[0]
def forward(self, inputs, input_paddings):
@@ -467,7 +494,8 @@ def forward(self, inputs, input_paddings):
pcm_audio_chunk = inputs * self.input_scale_factor
framed_signal = frame(
- pcm_audio_chunk, self._frame_size, self._frame_step, pad_end=p.pad_end)
+ pcm_audio_chunk, self._frame_size, self._frame_step, pad_end=p.pad_end
+ )
if p.preemph != 0.0:
preemphasized = self._apply_preemphasis(framed_signal)
@@ -497,12 +525,14 @@ def forward(self, inputs, input_paddings):
class MelFilterbankFrontend(nn.Module):
"""Layer to compute log mel spectograms from input audio signals."""
- def __init__(self,
- config: PreprocessorConfig = None,
- use_divide_stream: bool = True,
- per_bin_mean: Optional[float] = None,
- per_bin_stddev: Optional[float] = None,
- device='cpu'):
+ def __init__(
+ self,
+ config: PreprocessorConfig = None,
+ use_divide_stream: bool = True,
+ per_bin_mean: Optional[float] = None,
+ per_bin_stddev: Optional[float] = None,
+ device='cpu',
+ ):
super().__init__()
self.config = config
@@ -513,7 +543,8 @@ def __init__(self,
input_scale_factor = 2**-15 if self.use_divide_stream else 1.0
self.stft = SpectrogramFrontend(
- p, input_scale_factor=input_scale_factor, output_log=False)
+ p, input_scale_factor=input_scale_factor, output_log=False
+ )
if self.per_bin_mean is None:
per_bin_mean = [0.0] * p.num_bins
@@ -525,10 +556,13 @@ def __init__(self,
else:
per_bin_stddev = self.per_bin_stddev
- self.register_buffer('_normalizer_mean',
- torch.FloatTensor(per_bin_mean)[None, None, :, None])
- self.register_buffer('_normalizer_stddev',
- torch.FloatTensor(per_bin_stddev)[None, None, :, None])
+ self.register_buffer(
+ '_normalizer_mean', torch.FloatTensor(per_bin_mean)[None, None, :, None]
+ )
+ self.register_buffer(
+ '_normalizer_stddev',
+ torch.FloatTensor(per_bin_stddev)[None, None, :, None],
+ )
def forward(self, inputs, input_paddings):
p = self.config
@@ -536,20 +570,24 @@ def forward(self, inputs, input_paddings):
spect, spect_paddings = self.stft(inputs, input_paddings)
mel_weights = linear_to_mel_weight_matrix(
- num_mel_bins=p.num_bins,
- num_spectrogram_bins=spect.shape[2],
- sample_rate=p.sample_rate,
- lower_edge_hertz=p.lower_edge_hertz,
- upper_edge_hertz=p.upper_edge_hertz,
- device=spect.device)
+ num_mel_bins=p.num_bins,
+ num_spectrogram_bins=spect.shape[2],
+ sample_rate=p.sample_rate,
+ lower_edge_hertz=p.lower_edge_hertz,
+ upper_edge_hertz=p.upper_edge_hertz,
+ device=spect.device,
+ )
mel_spectrogram = torch.einsum('fn,btfc->btnc', mel_weights, spect)
logmel_spectrogram = torch.log(
- mel_spectrogram.clamp(min=p.output_floor, max=None))
+ mel_spectrogram.clamp(min=p.output_floor, max=None)
+ )
normalized_logmel_spectrogram = (
- (logmel_spectrogram - self._normalizer_mean) / self._normalizer_stddev)
+ logmel_spectrogram - self._normalizer_mean
+ ) / self._normalizer_stddev
- normalized_logmel_spectrogram = torch.squeeze(normalized_logmel_spectrogram,
- -1)
+ normalized_logmel_spectrogram = torch.squeeze(
+ normalized_logmel_spectrogram, -1
+ )
return normalized_logmel_spectrogram, spect_paddings
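
The preprocessor above maps linear frequencies to the mel scale with the HTK-style formula `1127 * ln(1 + f / 700)` (the `_MEL_HIGH_FREQUENCY_Q` and `_MEL_BREAK_FREQUENCY_HERTZ` constants) before building triangular filters in `linear_to_mel_weight_matrix`. A self-contained sketch of the scalar mapping, useful as a sanity check; the `hertz_to_mel` helper here is a simplified stand-in, not the repository's tensor-aware `_hertz_to_mel`:

```python
# Sketch of the HTK-style hertz-to-mel mapping used by the preprocessor above.
# Scalar-only illustration; the repo version also accepts torch tensors.
import math

_MEL_BREAK_FREQUENCY_HERTZ = 700.0
_MEL_HIGH_FREQUENCY_Q = 1127.0


def hertz_to_mel(frequency_hertz: float) -> float:
  return _MEL_HIGH_FREQUENCY_Q * math.log(
    1.0 + frequency_hertz / _MEL_BREAK_FREQUENCY_HERTZ
  )


# Under the 1127 * ln(1 + f / 700) convention, 1000 Hz lands at roughly 1000 mel.
print(round(hertz_to_mel(1000.0), 1))  # ~1000.0
```
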
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py
index 11b93703e..66db657b8 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py
@@ -9,19 +9,21 @@
class SpecAug(nn.Module):
"""Layer performs masking prodecure along time and frequency axis.
- The procedure is detailed in https://arxiv.org/abs/1904.08779.
- This is an essential component in speech recognition models that
- helps achieve better word error rates.
- """
-
- def __init__(self,
- freq_mask_count: int = 1,
- freq_mask_max_bins: int = 15,
- time_mask_count: int = 1,
- time_mask_max_frames: int = 50,
- time_mask_max_ratio: float = 1.0,
- time_masks_per_frame: float = 0.0,
- use_dynamic_time_mask_max_frames: bool = False):
+ The procedure is detailed in https://arxiv.org/abs/1904.08779.
+ This is an essential component in speech recognition models that
+ helps achieve better word error rates.
+ """
+
+ def __init__(
+ self,
+ freq_mask_count: int = 1,
+ freq_mask_max_bins: int = 15,
+ time_mask_count: int = 1,
+ time_mask_max_frames: int = 50,
+ time_mask_max_ratio: float = 1.0,
+ time_masks_per_frame: float = 0.0,
+ use_dynamic_time_mask_max_frames: bool = False,
+ ):
super().__init__()
self.freq_mask_count = freq_mask_count
@@ -35,23 +37,26 @@ def __init__(self,
def next_prng_key(self, name='dropout'):
return self.make_rng(name)
- def _get_mask(self,
- batch_size,
- choose_range,
- mask_size,
- max_length=None,
- masks_per_frame=0.0,
- multiplicity=1,
- max_ratio=1.0,
- device='cpu'):
+ def _get_mask(
+ self,
+ batch_size,
+ choose_range,
+ mask_size,
+ max_length=None,
+ masks_per_frame=0.0,
+ multiplicity=1,
+ max_ratio=1.0,
+ device='cpu',
+ ):
# Sample lengths for multiple masks.
if max_length and max_length > 0:
max_length = max_length * torch.ones(batch_size, device=device)
else:
max_length = choose_range * max_ratio
masked_portion = torch.rand(batch_size, multiplicity, device=device)
- masked_frame_size = torch.einsum('b,bm->bm', max_length,
- masked_portion).long()
+ masked_frame_size = torch.einsum(
+ 'b,bm->bm', max_length, masked_portion
+ ).long()
# Make sure the sampled length was smaller than max_ratio * length_bound.
# Note that sampling in this way was biased
# (shorter sequence may over-masked.)
@@ -80,8 +85,9 @@ def _get_mask(self,
# Sum masks with appropriate multiplicity.
if masks_per_frame > 0:
multiplicity_weights = torch.tile(
- torch.arange(multiplicity, device=device).long()[None, ...],
- [batch_size, 1])
+ torch.arange(multiplicity, device=device).long()[None, ...],
+ [batch_size, 1],
+ )
multiplicity_tensor = masks_per_frame * choose_range
multiplicity_weights = (multiplicity_weights < multiplicity_tensor).long()
pre_mask = torch.einsum('bmt,bm->bt', pre_mask, multiplicity_weights)
@@ -99,8 +105,9 @@ def _time_mask(self, inputs, length):
max_ratio = self.time_mask_max_ratio
# If maximum mask length is zero, do nothing.
- if ((time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames) or
- max_ratio <= 0.0):
+ if (
+ time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames
+ ) or max_ratio <= 0.0:
return inputs
if multiplicity == 0:
return inputs
@@ -112,14 +119,15 @@ def _time_mask(self, inputs, length):
time_mask_max_frames = None
# Create masks in time direction and apply.
block_arrays = self._get_mask(
- batch_size,
- choose_range=length,
- mask_size=time_length,
- max_length=time_mask_max_frames,
- masks_per_frame=self.time_masks_per_frame,
- multiplicity=multiplicity,
- max_ratio=max_ratio,
- device=inputs.device)
+ batch_size,
+ choose_range=length,
+ mask_size=time_length,
+ max_length=time_mask_max_frames,
+ masks_per_frame=self.time_masks_per_frame,
+ multiplicity=multiplicity,
+ max_ratio=max_ratio,
+ device=inputs.device,
+ )
outputs = torch.einsum('bxy,bx->bxy', inputs, block_arrays)
return outputs
@@ -138,14 +146,15 @@ def _frequency_mask(self, inputs):
choose_range = num_freq * torch.ones(batch_size, device=inputs.device)
# Create masks in frequency direction and apply.
block_arrays = self._get_mask(
- batch_size,
- choose_range=choose_range,
- mask_size=num_freq,
- max_length=freq_mask_max_bins,
- masks_per_frame=0.0,
- multiplicity=multiplicity,
- max_ratio=1.0,
- device=inputs.device)
+ batch_size,
+ choose_range=choose_range,
+ mask_size=num_freq,
+ max_length=freq_mask_max_bins,
+ masks_per_frame=0.0,
+ multiplicity=multiplicity,
+ max_ratio=1.0,
+ device=inputs.device,
+ )
outputs = torch.einsum('bxy,by->bxy', inputs, block_arrays)
return outputs
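
`SpecAug` above builds per-example 0/1 masks and applies them with `torch.einsum` in `_time_mask` and `_frequency_mask`. A deliberately simplified, loop-based sketch of a single frequency mask; this is not the repository's vectorized multi-mask implementation, and `max_bins=15` merely mirrors the default `freq_mask_max_bins`:

```python
# Simplified single-frequency-mask sketch; the application step ('bxy,by->bxy')
# matches _frequency_mask above, but the mask sampling is a plain Python loop.
import torch


def frequency_mask(features: torch.Tensor, max_bins: int = 15) -> torch.Tensor:
  """features: [batch, time, freq]; zeroes one random band of <= max_bins bins."""
  batch, _, num_freq = features.shape
  mask = torch.ones(batch, num_freq, device=features.device)
  for b in range(batch):
    width = int(torch.randint(0, max_bins + 1, (1,)))
    start = int(torch.randint(0, max(num_freq - width, 1), (1,)))
    mask[b, start:start + width] = 0.0
  return torch.einsum('bxy,by->bxy', features, mask)


augmented = frequency_mask(torch.randn(2, 100, 80))
```
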
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
index 5ed37957e..25416682c 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
@@ -10,15 +10,12 @@
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
-from algoperf import data_utils
-from algoperf import param_utils
-from algoperf import pytorch_utils
-from algoperf import spec
import algoperf.random_utils as prng
-from algoperf.workloads.librispeech_conformer import metrics
-from algoperf.workloads.librispeech_conformer import workload
-from algoperf.workloads.librispeech_conformer.input_pipeline import \
- LibriSpeechDataset
+from algoperf import data_utils, param_utils, pytorch_utils, spec
+from algoperf.workloads.librispeech_conformer import metrics, workload
+from algoperf.workloads.librispeech_conformer.input_pipeline import (
+ LibriSpeechDataset,
+)
from algoperf.workloads.librispeech_conformer.librispeech_pytorch import models
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
@@ -27,10 +24,9 @@
class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload):
-
- def __init__(self,
- tokenizer_vocab_path: Optional[str] = None,
- use_specaug: bool = True) -> None:
+ def __init__(
+ self, tokenizer_vocab_path: Optional[str] = None, use_specaug: bool = True
+ ) -> None:
super().__init__()
self.tokenizer = metrics.load_tokenizer(tokenizer_vocab_path)
self.use_specaug = use_specaug
@@ -42,7 +38,8 @@ def __init__(self,
def eval_num_workers(self) -> int:
if self._eval_num_workers is None:
raise ValueError(
- 'eval_num_workers property must be set before workload is used.')
+ 'eval_num_workers property must be set before workload is used.'
+ )
return self._eval_num_workers
@eval_num_workers.setter
@@ -61,16 +58,8 @@ def use_gelu(self) -> bool:
def attention_temperature(self) -> float:
return 1.0
- def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """Conformer model init function.
-
- Here we use dropout_rate as residual_dropout_rate, and aux_dropout_rate as
- input_dropout_rate.
- """
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
+ """Conformer model init function."""
torch.random.manual_seed(rng[0])
# Configure torch backends to avoid OOM errors.
torch.backends.cudnn.benchmark = False
@@ -82,15 +71,13 @@ def init_model_fn(
else:
activation_function_name = 'swish'
model = models.ConformerEncoderDecoder(
- models.ConformerConfig(
- attention_residual_dropout_rate=dropout_rate,
- feed_forward_residual_dropout_rate=dropout_rate,
- conv_residual_dropout_rate=dropout_rate,
- input_dropout_rate=aux_dropout_rate,
- use_specaug=self.use_specaug,
- attention_temperature=self.attention_temperature,
- use_post_layer_norm=self.use_post_layer_norm,
- activation_function_name=activation_function_name))
+ models.ConformerConfig(
+ use_specaug=self.use_specaug,
+ attention_temperature=self.attention_temperature,
+ use_post_layer_norm=self.use_post_layer_norm,
+ activation_function_name=activation_function_name,
+ )
+ )
self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none')
models.initialize(model)
self._param_shapes = param_utils.pytorch_param_shapes(model)
@@ -109,13 +96,15 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['lin.weight', 'lin.bias']
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
@@ -125,29 +114,33 @@ def model_fn(
if mode == spec.ForwardPassMode.TRAIN:
model.train()
model.apply(
- functools.partial(
- pytorch_utils.update_batch_norm_fn,
- update_batch_norm=update_batch_norm))
+ functools.partial(
+ pytorch_utils.update_batch_norm_fn,
+ update_batch_norm=update_batch_norm,
+ )
+ )
contexts = {
- spec.ForwardPassMode.EVAL: torch.no_grad,
- spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+ spec.ForwardPassMode.EVAL: torch.no_grad,
+ spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
- logits, logits_paddings = model(inputs.to(DEVICE),
- input_paddings.to(DEVICE))
+ logits, logits_paddings = model(
+ inputs.to(DEVICE), input_paddings.to(DEVICE), dropout_rate=dropout_rate
+ )
return (logits, logits_paddings), None
def _build_input_queue(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
del cache
del repeat_final_dataset
del num_batches
@@ -166,7 +159,7 @@ def _build_input_queue(
if split == 'eval_train':
indices = list(range(len(ds)))
random.Random(int(data_rng[0])).shuffle(indices)
- ds = torch.utils.data.Subset(ds, indices[:self.num_eval_train_examples])
+ ds = torch.utils.data.Subset(ds, indices[: self.num_eval_train_examples])
sampler = None
if USE_PYTORCH_DDP:
@@ -177,31 +170,36 @@ def _build_input_queue(
if USE_PYTORCH_DDP:
if is_train:
sampler = torch.utils.data.distributed.DistributedSampler(
- ds, num_replicas=N_GPUS, rank=RANK, shuffle=True)
+ ds, num_replicas=N_GPUS, rank=RANK, shuffle=True
+ )
else:
sampler = data_utils.DistributedEvalSampler(
- ds, num_replicas=N_GPUS, rank=RANK, shuffle=False)
+ ds, num_replicas=N_GPUS, rank=RANK, shuffle=False
+ )
dataloader = torch.utils.data.DataLoader(
- ds,
- batch_size=ds_iter_batch_size,
- shuffle=not USE_PYTORCH_DDP and is_train,
- sampler=sampler,
- num_workers=4,
- pin_memory=True,
- drop_last=is_train)
+ ds,
+ batch_size=ds_iter_batch_size,
+ shuffle=not USE_PYTORCH_DDP and is_train,
+ sampler=sampler,
+ num_workers=4,
+ pin_memory=True,
+ drop_last=is_train,
+ )
dataloader = data_utils.cycle(
- dataloader, custom_sampler=USE_PYTORCH_DDP, use_mixup=False)
+ dataloader, custom_sampler=USE_PYTORCH_DDP, use_mixup=False
+ )
return dataloader
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding)
- logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding)
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding)
+ logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding)
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -215,10 +213,8 @@ def loss_fn(
input_lengths = torch.einsum('bh->b', 1 - logit_paddings).long()
target_lengths = torch.einsum('bh->b', 1 - target_paddings).long()
per_example_losses = self.ctc_loss(
- logprobs.permute(1, 0, 2),
- targets.long(),
- input_lengths,
- target_lengths)
+ logprobs.permute(1, 0, 2), targets.long(), input_lengths, target_lengths
+ )
# mask_batch is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -229,21 +225,22 @@ def loss_fn(
summed_loss = per_example_losses.sum()
n_valid_examples = max(n_valid_examples, 1)
return {
- 'summed': summed_loss,
- 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
+ 'per_example': per_example_losses,
}
def greedy_decode(
- self, logits: spec.Tensor,
- logit_paddings: spec.Tensor) -> Tuple[spec.Tensor, spec.Tensor]:
+ self, logits: spec.Tensor, logit_paddings: spec.Tensor
+ ) -> Tuple[spec.Tensor, spec.Tensor]:
framewise_tokens = logits.max(dim=-1)[1]
framewise_tokens = framewise_tokens * (1 - logit_paddings)
# Add sentinel because unique_consecutive will flatten array
# and then compute the unique.
framewise_tokens = torch.cat(
- [framewise_tokens, -torch.ones_like(framewise_tokens[:, 0:1])], dim=1)
+ [framewise_tokens, -torch.ones_like(framewise_tokens[:, 0:1])], dim=1
+ )
_, indices = torch.unique_consecutive(framewise_tokens, return_inverse=True)
indices -= indices.min(dim=1, keepdims=True)[0]
result = torch.zeros_like(framewise_tokens)
@@ -256,11 +253,12 @@ def greedy_decode(
# Remove blanks (id = 0).
blank_id = 0
fin_result = torch.zeros_like(result)
- idxs = torch.arange(
- fin_result.numel(), device=result.device).view(*fin_result.shape)
- mask = torch.arange(
- fin_result.shape[1], device=result.device).view(
- 1, -1) < result.count_nonzero(dim=1).view(-1, 1)
+ idxs = torch.arange(fin_result.numel(), device=result.device).view(
+ *fin_result.shape
+ )
+ mask = torch.arange(fin_result.shape[1], device=result.device).view(
+ 1, -1
+ ) < result.count_nonzero(dim=1).view(-1, 1)
fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id]
padding = fin_result == 0
return fin_result, padding
@@ -274,29 +272,31 @@ def sync_sd(self, params: spec.ParameterContainer) -> None:
sd[k] = sd[k] / N_GPUS
params.load_state_dict(sd)
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del global_step
data_rng, model_rng = prng.split(rng, 2)
if split not in self._eval_iters:
# These iterators repeat indefinitely.
- self._eval_iters[split] = (
- self._build_input_queue(
- data_rng, split, data_dir, global_batch_size=global_batch_size))
+ self._eval_iters[split] = self._build_input_queue(
+ data_rng, split, data_dir, global_batch_size=global_batch_size
+ )
total_metrics = {
- 'loss': torch.tensor(0., device=DEVICE),
- 'lengths': torch.tensor(0., device=DEVICE),
- 'word_errors': torch.tensor(0., device=DEVICE),
- 'num_words': torch.tensor(0., device=DEVICE),
+ 'loss': torch.tensor(0.0, device=DEVICE),
+ 'lengths': torch.tensor(0.0, device=DEVICE),
+ 'word_errors': torch.tensor(0.0, device=DEVICE),
+ 'num_words': torch.tensor(0.0, device=DEVICE),
}
num_batches = int(math.ceil(num_examples / global_batch_size))
if self.requires_sync_before_eval:
@@ -305,48 +305,50 @@ def _eval_model_on_split(self,
batch = next(self._eval_iters[split])
(logits, logits_padding), _ = self.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.EVAL,
- model_rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.EVAL,
+ model_rng,
+ update_batch_norm=False,
+ )
decoded, decoded_paddings = self.greedy_decode(logits, logits_padding)
targets, target_paddings = batch['targets']
word_errors, num_words = metrics.compute_wer(
- decoded=decoded.cpu().numpy(),
- decoded_paddings=decoded_paddings.cpu().numpy(),
- targets=targets.cpu().numpy(),
- target_paddings=target_paddings.cpu().numpy(),
- tokenizer=self.tokenizer)
+ decoded=decoded.cpu().numpy(),
+ decoded_paddings=decoded_paddings.cpu().numpy(),
+ targets=targets.cpu().numpy(),
+ target_paddings=target_paddings.cpu().numpy(),
+ tokenizer=self.tokenizer,
+ )
loss = self.loss_fn((targets, target_paddings), (logits, logits_padding))
summed_loss = loss['summed']
lengths = loss['n_valid_examples']
batch_metrics = {
- 'loss': summed_loss,
- 'lengths': lengths,
- 'word_errors': word_errors,
- 'num_words': num_words,
+ 'loss': summed_loss,
+ 'lengths': lengths,
+ 'word_errors': word_errors,
+ 'num_words': num_words,
}
total_metrics = {
- k: v + batch_metrics[k] for k, v in total_metrics.items()
+ k: v + batch_metrics[k] for k, v in total_metrics.items()
}
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
dist.all_reduce(metric)
return {
- 'ctc_loss':
- float(total_metrics['loss'].item() /
- total_metrics['lengths'].item()),
- 'wer':
- float(total_metrics['word_errors'].item() /
- total_metrics['num_words'].item()),
+ 'ctc_loss': float(
+ total_metrics['loss'].item() / total_metrics['lengths'].item()
+ ),
+ 'wer': float(
+ total_metrics['word_errors'].item() / total_metrics['num_words'].item()
+ ),
}
class LibriSpeechConformerAttentionTemperatureWorkload(
- LibriSpeechConformerWorkload):
-
+ LibriSpeechConformerWorkload
+):
@property
def attention_temperature(self) -> float:
return 1.6
@@ -361,7 +363,6 @@ def test_target_value(self) -> float:
class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload):
-
@property
def use_post_layer_norm(self) -> bool:
return False
@@ -376,7 +377,6 @@ def test_target_value(self) -> float:
class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload):
-
@property
def use_gelu(self) -> bool:
return True
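
With the workload changes above, dropout is no longer configured in `init_model_fn`; it is passed per call to `model_fn` and falls back to `models.DROPOUT_RATE` when omitted. A hedged sketch of the new calling convention, where `workload`, `params`, `batch`, `model_state`, and `rng` stand in for objects the training harness creates elsewhere:

```python
# Sketch of the dropout plumbing through model_fn per the new signature above.
# Only the keyword arguments are the point; the surrounding objects are
# placeholders created elsewhere in the harness.
from typing import Dict, Tuple

from algoperf import spec


def forward_with_dropout(
  workload,  # a LibriSpeechConformerWorkload instance
  params: spec.ParameterContainer,
  batch: Dict[str, spec.Tensor],
  model_state: spec.ModelAuxiliaryState,
  rng: spec.RandomState,
  dropout_rate: float = 0.1,  # assumed value; defaults to models.DROPOUT_RATE if omitted
) -> Tuple[spec.Tensor, spec.Tensor]:
  (logits, logit_paddings), _ = workload.model_fn(
    params,
    batch,
    model_state,
    spec.ForwardPassMode.TRAIN,
    rng,
    update_batch_norm=True,
    dropout_rate=dropout_rate,
  )
  return logits, logit_paddings
```
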
diff --git a/algoperf/workloads/librispeech_conformer/metrics.py b/algoperf/workloads/librispeech_conformer/metrics.py
index de74cfe1b..7dd6a11dc 100644
--- a/algoperf/workloads/librispeech_conformer/metrics.py
+++ b/algoperf/workloads/librispeech_conformer/metrics.py
@@ -1,8 +1,8 @@
-from clu import metrics
import flax
import numpy as np
import tensorflow as tf
import tensorflow_text as tftxt
+from clu import metrics
gfile = tf.io.gfile
@@ -15,17 +15,20 @@ def average_ctc_loss():
@flax.struct.dataclass
class _Metric(metrics.Metric):
"""Applies `fun` and computes the average."""
+
total: np.float32
weight: np.float32
@classmethod
def from_model_output(cls, loss_dict, **_):
return cls(
- total=loss_dict['summed'], weight=loss_dict['n_valid_examples'])
+ total=loss_dict['summed'], weight=loss_dict['n_valid_examples']
+ )
def merge(self, other):
return type(self)(
- total=self.total + other.total, weight=self.weight + other.weight)
+ total=self.total + other.total, weight=self.weight + other.weight
+ )
def compute(self):
return self.total / self.weight
@@ -74,9 +77,10 @@ def edit_distance(source, target):
# possibilities and find minimum.
else:
distance[i][j] = 1 + min(
- distance[i][j - 1], # Insert
- distance[i - 1][j], # Remove
- distance[i - 1][j - 1]) # Replace
+ distance[i][j - 1], # Insert
+ distance[i - 1][j], # Remove
+ distance[i - 1][j - 1],
+ ) # Replace
return distance[num_source_words][num_target_words]
@@ -109,17 +113,20 @@ def compute_wer(decoded, decoded_paddings, targets, target_paddings, tokenizer):
return word_errors, num_words
-def load_tokenizer(model_path: str,
- add_bos: bool = False,
- add_eos: bool = True,
- reverse: bool = False):
+def load_tokenizer(
+ model_path: str,
+ add_bos: bool = False,
+ add_eos: bool = True,
+ reverse: bool = False,
+):
"""Load a tf-text SentencePiece tokenizer from given model filepath."""
if model_path is None:
return None
with gfile.GFile(model_path, 'rb') as model_fp:
sp_model = model_fp.read()
sp_tokenizer = tftxt.SentencepieceTokenizer(
- model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse)
+ model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse
+ )
return sp_tokenizer
@@ -128,8 +135,10 @@ def wer(tokenizer_vocab_path):
@flax.struct.dataclass
class WER(
- metrics.CollectingMetric.from_outputs(
- ('decoded', 'decoded_paddings', 'targets', 'target_paddings'))):
+ metrics.CollectingMetric.from_outputs(
+ ('decoded', 'decoded_paddings', 'targets', 'target_paddings')
+ )
+ ):
"""Computes the mean average precision for a binary classifier on CPU."""
def compute(self):
@@ -144,7 +153,8 @@ def compute(self):
values['decoded_paddings'],
values['targets'].astype(np.int32),
values['target_paddings'],
- tokenizer)
+ tokenizer,
+ )
return word_errors / num_words
@@ -153,4 +163,5 @@ def compute(self):
def get_metrics_bundle(tokenizer_vocab_path):
return metrics.Collection.create(
- ctc_loss=average_ctc_loss(), wer=wer(tokenizer_vocab_path))
+ ctc_loss=average_ctc_loss(), wer=wer(tokenizer_vocab_path)
+ )
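
`compute_wer` above sums word-level edit distances over a batch and divides by the number of reference words; `edit_distance` is the standard Levenshtein recurrence. A standalone version of that recurrence, for intuition only:

```python
# Standalone word-level Levenshtein distance, the same recurrence as
# edit_distance() above. Corpus WER is sum(distances) / sum(reference words).
def word_edit_distance(source: list, target: list) -> int:
  n, m = len(source), len(target)
  dist = [[0] * (m + 1) for _ in range(n + 1)]
  for i in range(n + 1):
    for j in range(m + 1):
      if i == 0:
        dist[i][j] = j                 # insert all of target
      elif j == 0:
        dist[i][j] = i                 # remove all of source
      elif source[i - 1] == target[j - 1]:
        dist[i][j] = dist[i - 1][j - 1]
      else:
        dist[i][j] = 1 + min(
          dist[i][j - 1],              # insert
          dist[i - 1][j],              # remove
          dist[i - 1][j - 1],          # replace
        )
  return dist[n][m]


assert word_edit_distance('a cat sat'.split(), 'a cat sat down'.split()) == 1
```
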
diff --git a/algoperf/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py
index 94f01dd97..791270719 100644
--- a/algoperf/workloads/librispeech_conformer/workload.py
+++ b/algoperf/workloads/librispeech_conformer/workload.py
@@ -5,7 +5,6 @@
class BaseLibrispeechWorkload(spec.Workload):
-
_num_outputs: int = 1024
@property
@@ -25,8 +24,9 @@ def use_gelu(self) -> bool:
def attention_temperature(self) -> float:
raise NotImplementedError
- def has_reached_validation_target(self, eval_result: Dict[str,
- float]) -> bool:
+ def has_reached_validation_target(
+ self, eval_result: Dict[str, float]
+ ) -> bool:
return eval_result['validation/wer'] < self.validation_target_value
@property
@@ -53,8 +53,9 @@ def num_eval_train_examples(self) -> int:
# Round up from num_validation_examples (which is the default for
# num_eval_train_examples) to the next multiple of eval_batch_size, so that
# we don't have to extract the correctly sized subset of the training data.
- rounded_up_multiple = math.ceil(self.num_validation_examples /
- self.eval_batch_size)
+ rounded_up_multiple = math.ceil(
+ self.num_validation_examples / self.eval_batch_size
+ )
return rounded_up_multiple * self.eval_batch_size
@property
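
The `num_eval_train_examples` property above rounds the validation-set size up to the next multiple of `eval_batch_size` so that evaluation never needs a partial batch. Spelled out with hypothetical numbers (5,348 examples and a batch size of 256 are illustrative, not the workload's actual values):

```python
# The rounding logic above, made explicit with illustrative numbers.
import math

num_validation_examples = 5348   # hypothetical
eval_batch_size = 256            # hypothetical
rounded_up_multiple = math.ceil(num_validation_examples / eval_batch_size)  # 21
num_eval_train_examples = rounded_up_multiple * eval_batch_size             # 5376
```
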
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
index b116f44cd..225852b28 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
@@ -1,4 +1,4 @@
-r"""Deepspeech.
+"""Deepspeech.
This model uses a deepspeech2 network to convert speech to text.
paper : https://arxiv.org/abs/1512.02595
@@ -10,16 +10,19 @@
from typing import Any, Dict, List, Optional, Tuple, Union
+import jax
+import jax.numpy as jnp
from flax import linen as nn
from flax import struct
-import jax
from jax.experimental import rnn
-import jax.numpy as jnp
-from algoperf.workloads.librispeech_conformer.librispeech_jax import \
- librispeech_preprocessor as preprocessor
-from algoperf.workloads.librispeech_conformer.librispeech_jax import \
- spectrum_augmenter
+from algoperf.jax_utils import Dropout
+from algoperf.workloads.librispeech_conformer.librispeech_jax import (
+ librispeech_preprocessor as preprocessor,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_jax import (
+ spectrum_augmenter,
+)
Array = jnp.ndarray
StateType = Union[Array, Tuple[Array, ...]]
@@ -30,10 +33,13 @@
CarryHistory = Any
Output = Any
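+# Default dropout rate shared by the Subsample and FeedForwardModule blocks;
+# callers can override it via the call-time `dropout_rate` argument.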
+DROPOUT_RATE = 0.1
+
@struct.dataclass
class DeepspeechConfig:
"""Global hyperparameters used to minimize obnoxious kwarg plumbing."""
+
vocab_size: int = 1024
dtype: Any = jnp.float32
encoder_dim: int = 512
@@ -51,10 +57,6 @@ class DeepspeechConfig:
use_dynamic_time_mask_max_frames: bool = True
batch_norm_momentum: float = 0.999
batch_norm_epsilon: float = 0.001
- # If None, defaults to 0.1.
- input_dropout_rate: Optional[float] = 0.1
- # If None, defaults to 0.1.
- feed_forward_dropout_rate: Optional[float] = 0.1
enable_residual_connections: bool = True
enable_decoder_layer_norm: bool = True
bidirectional: bool = True
@@ -69,50 +71,49 @@ class Subsample(nn.Module):
encoder_dim: model dimension of conformer.
-    input_dropout_rate: dropout rate for inputs.
+    dropout_rate: dropout rate for inputs (now a call-time argument).
"""
+
config: DeepspeechConfig
@nn.compact
- def __call__(self, inputs, output_paddings, train):
+ def __call__(self, inputs, output_paddings, train, dropout_rate=DROPOUT_RATE):
config = self.config
outputs = jnp.expand_dims(inputs, axis=-1)
outputs, output_paddings = Conv2dSubsampling(
- encoder_dim=config.encoder_dim,
- dtype=config.dtype,
- batch_norm_momentum=config.batch_norm_momentum,
- batch_norm_epsilon=config.batch_norm_epsilon,
- input_channels=1,
- output_channels=config.encoder_dim,
- use_tanh=config.use_tanh
- )(outputs, output_paddings, train)
+ encoder_dim=config.encoder_dim,
+ dtype=config.dtype,
+ batch_norm_momentum=config.batch_norm_momentum,
+ batch_norm_epsilon=config.batch_norm_epsilon,
+ input_channels=1,
+ output_channels=config.encoder_dim,
+ use_tanh=config.use_tanh,
+ )(outputs, output_paddings, train)
outputs, output_paddings = Conv2dSubsampling(
- encoder_dim=config.encoder_dim,
- dtype=config.dtype,
- batch_norm_momentum=config.batch_norm_momentum,
- batch_norm_epsilon=config.batch_norm_epsilon,
- input_channels=config.encoder_dim,
- output_channels=config.encoder_dim,
- use_tanh=config.use_tanh)(outputs, output_paddings, train)
+ encoder_dim=config.encoder_dim,
+ dtype=config.dtype,
+ batch_norm_momentum=config.batch_norm_momentum,
+ batch_norm_epsilon=config.batch_norm_epsilon,
+ input_channels=config.encoder_dim,
+ output_channels=config.encoder_dim,
+ use_tanh=config.use_tanh,
+ )(outputs, output_paddings, train)
batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape
outputs = jnp.reshape(
- outputs, (batch_size, subsampled_lengths, subsampled_dims * channels))
+ outputs, (batch_size, subsampled_lengths, subsampled_dims * channels)
+ )
outputs = nn.Dense(
- config.encoder_dim,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- outputs)
+ config.encoder_dim,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(outputs)
- if config.input_dropout_rate is None:
- input_dropout_rate = 0.1
- else:
- input_dropout_rate = config.input_dropout_rate
- outputs = nn.Dropout(
- rate=input_dropout_rate, deterministic=not train)(
- outputs)
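+    # The Dropout wrapper from algoperf.jax_utils also accepts a call-time
+    # `rate`, so the rate is passed both at construction and at call time.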
+ outputs = Dropout(rate=dropout_rate, deterministic=not train)(
+ outputs, rate=dropout_rate
+ )
return outputs, output_paddings
@@ -124,6 +125,7 @@ class Conv2dSubsampling(nn.Module):
2) Also performs strided convolution over input_paddings to return the correct
paddings for downstream layers.
"""
+
input_channels: int = 0
output_channels: int = 0
filter_stride: List[int] = (2, 2)
@@ -136,24 +138,26 @@ class Conv2dSubsampling(nn.Module):
def setup(self):
self.filter_shape = (3, 3, self.input_channels, self.output_channels)
- self.kernel = self.param('kernel',
- nn.initializers.xavier_uniform(),
- self.filter_shape)
+ self.kernel = self.param(
+ 'kernel', nn.initializers.xavier_uniform(), self.filter_shape
+ )
self.bias = self.param(
- 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels)
+ 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels
+ )
@nn.compact
def __call__(self, inputs, paddings, train):
# Computing strided convolution to subsample inputs.
feature_group_count = inputs.shape[3] // self.filter_shape[2]
outputs = jax.lax.conv_general_dilated(
- lhs=inputs,
- rhs=self.kernel,
- window_strides=self.filter_stride,
- padding=self.padding,
- rhs_dilation=(1, 1),
- dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
- feature_group_count=feature_group_count)
+ lhs=inputs,
+ rhs=self.kernel,
+ window_strides=self.filter_stride,
+ padding=self.padding,
+ rhs_dilation=(1, 1),
+ dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
+ feature_group_count=feature_group_count,
+ )
outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,))
@@ -168,56 +172,58 @@ def __call__(self, inputs, paddings, train):
pad_len = (input_length + stride - 1) // stride * stride - input_length
out_padding = jax.lax.conv_general_dilated(
- lhs=paddings[:, :, None],
- rhs=jnp.ones([1, 1, 1]),
- window_strides=self.filter_stride[:1],
- padding=[(0, pad_len)],
- dimension_numbers=('NHC', 'HIO', 'NHC'))
+ lhs=paddings[:, :, None],
+ rhs=jnp.ones([1, 1, 1]),
+ window_strides=self.filter_stride[:1],
+ padding=[(0, pad_len)],
+ dimension_numbers=('NHC', 'HIO', 'NHC'),
+ )
out_padding = jnp.squeeze(out_padding, axis=-1)
# Mask outputs by correct paddings to ensure padded elements in inputs map
# to padded value in outputs.
- outputs = outputs * (1.0 -
- jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1))
+ outputs = outputs * (
+ 1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1)
+ )
return outputs, out_padding
class FeedForwardModule(nn.Module):
"""Feedforward block of conformer layer."""
+
config: DeepspeechConfig
@nn.compact
- def __call__(self, inputs, input_paddings=None, train=False):
+ def __call__(
+ self, inputs, input_paddings=None, train=False, dropout_rate=DROPOUT_RATE
+ ):
padding_mask = jnp.expand_dims(1 - input_paddings, -1)
config = self.config
if config.layernorm_everywhere:
inputs = LayerNorm(config.encoder_dim)(inputs)
else:
- inputs = BatchNorm(config.encoder_dim,
- config.dtype,
- config.batch_norm_momentum,
- config.batch_norm_epsilon)(inputs,
- input_paddings,
- train)
- inputs = nn.Dense(
+ inputs = BatchNorm(
config.encoder_dim,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- inputs)
+ config.dtype,
+ config.batch_norm_momentum,
+ config.batch_norm_epsilon,
+ )(inputs, input_paddings, train)
+ inputs = nn.Dense(
+ config.encoder_dim,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(inputs)
if config.use_tanh:
inputs = nn.tanh(inputs)
else:
inputs = nn.relu(inputs)
inputs *= padding_mask
- if config.feed_forward_dropout_rate is None:
- feed_forward_dropout_rate = 0.1
- else:
- feed_forward_dropout_rate = config.feed_forward_dropout_rate
- inputs = nn.Dropout(rate=feed_forward_dropout_rate)(
- inputs, deterministic=not train)
+ inputs = Dropout(rate=dropout_rate)(
+ inputs, deterministic=not train, rate=dropout_rate
+ )
return inputs
@@ -232,6 +238,7 @@ class LayerNorm(nn.Module):
zeros, this differs from default flax implementation of multiplying by scale
and initializing to ones.
"""
+
dim: int = 0
epsilon: float = 1e-6
@@ -245,7 +252,7 @@ def __call__(self, inputs):
var = jnp.mean(jnp.square(inputs - mean), axis=-1, keepdims=True)
normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon)
- normed_inputs *= (1 + self.scale)
+ normed_inputs *= 1 + self.scale
normed_inputs += self.bias
return normed_inputs
@@ -263,6 +270,7 @@ class BatchNorm(nn.Module):
and the corresponding defaults for momentum and epsilon have been copied over
from lingvo.
"""
+
encoder_dim: int = 0
dtype: Any = jnp.float32
batch_norm_momentum: float = 0.999
@@ -272,14 +280,12 @@ def setup(self):
dim = self.encoder_dim
dtype = self.dtype
- self.ra_mean = self.variable('batch_stats',
- 'mean',
- lambda s: jnp.zeros(s, dtype),
- dim)
- self.ra_var = self.variable('batch_stats',
- 'var',
- lambda s: jnp.ones(s, dtype),
- dim)
+ self.ra_mean = self.variable(
+ 'batch_stats', 'mean', lambda s: jnp.zeros(s, dtype), dim
+ )
+ self.ra_var = self.variable(
+ 'batch_stats', 'var', lambda s: jnp.ones(s, dtype), dim
+ )
self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype)
self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)
@@ -308,7 +314,8 @@ def __call__(self, inputs, input_paddings=None, train=False):
mask = 1.0 - padding
sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True)
count_v = jnp.sum(
- jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True)
+ jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True
+ )
sum_v = jax.lax.psum(sum_v, axis_name='batch')
count_v = jax.lax.psum(count_v, axis_name='batch')
@@ -345,15 +352,14 @@ class CudnnLSTM(nn.Module):
@nn.compact
def __call__(
- self,
- inputs: Array,
- segmentation_mask: Optional[Array] = None,
- return_carry: Optional[bool] = None,
- deterministic: bool = False,
- initial_states: Optional[Tuple[Array, Array]] = None,
- use_cuda: bool = True,
+ self,
+ inputs: Array,
+ segmentation_mask: Optional[Array] = None,
+ return_carry: Optional[bool] = None,
+ deterministic: bool = False,
+ initial_states: Optional[Tuple[Array, Array]] = None,
+ use_cuda: bool = True,
) -> Union[Array, Tuple[Array, Carry]]:
-
if jax.devices()[0].platform != 'gpu':
use_cuda = False
@@ -363,22 +369,22 @@ def __call__(
dropout = 0.0 if deterministic else self.dropout_rate
weights = self.param(
- 'weights',
- rnn.init_lstm_weight,
- input_size,
- self.features,
- self.num_layers,
- self.bidirectional,
+ 'weights',
+ rnn.init_lstm_weight,
+ input_size,
+ self.features,
+ self.num_layers,
+ self.bidirectional,
)
if initial_states is None:
h_0 = jnp.zeros(
- (num_directions * self.num_layers, batch_size, self.features),
- jnp.float32,
+ (num_directions * self.num_layers, batch_size, self.features),
+ jnp.float32,
)
c_0 = jnp.zeros(
- (num_directions * self.num_layers, batch_size, self.features),
- jnp.float32,
+ (num_directions * self.num_layers, batch_size, self.features),
+ jnp.float32,
)
else:
h_0, c_0 = initial_states
@@ -390,20 +396,35 @@ def __call__(
if use_cuda:
y, h, c = rnn.lstm(
- x=inputs, h_0=h_0, c_0=c_0, weights=weights,
- seq_lengths=seq_lengths, input_size=input_size,
- hidden_size=self.features, num_layers=self.num_layers,
- dropout=dropout, bidirectional=self.bidirectional,
+ x=inputs,
+ h_0=h_0,
+ c_0=c_0,
+ weights=weights,
+ seq_lengths=seq_lengths,
+ input_size=input_size,
+ hidden_size=self.features,
+ num_layers=self.num_layers,
+ dropout=dropout,
+ bidirectional=self.bidirectional,
)
else:
weight_ih, weight_hh, bias_ih, bias_hh = self.unpack_weights(
- weights, input_size)
+ weights, input_size
+ )
y, h, c = rnn.lstm_ref(
- x=inputs, h_0=h_0, c_0=c_0, W_ih=weight_ih, W_hh=weight_hh,
- b_ih=bias_ih, b_hh=bias_hh, seq_lengths=seq_lengths,
- input_size=input_size, hidden_size=self.features,
- num_layers=self.num_layers, dropout=dropout,
- bidirectional=self.bidirectional,
+ x=inputs,
+ h_0=h_0,
+ c_0=c_0,
+ W_ih=weight_ih,
+ W_hh=weight_hh,
+ b_ih=bias_ih,
+ b_hh=bias_hh,
+ seq_lengths=seq_lengths,
+ input_size=input_size,
+ hidden_size=self.features,
+ num_layers=self.num_layers,
+ dropout=dropout,
+ bidirectional=self.bidirectional,
)
if return_carry:
@@ -413,21 +434,22 @@ def __call__(
@nn.nowrap
def unpack_weights(
- self, weights: Array, input_size: int
+ self, weights: Array, input_size: int
) -> Tuple[
- Dict[int, Array], Dict[int, Array], Dict[int, Array], Dict[int, Array]]:
+ Dict[int, Array], Dict[int, Array], Dict[int, Array], Dict[int, Array]
+ ]:
return jax.experimental.rnn.unpack_lstm_weights(
- weights,
- input_size,
- self.features,
- self.num_layers,
- self.bidirectional,
+ weights,
+ input_size,
+ self.features,
+ self.num_layers,
+ self.bidirectional,
)
class BatchRNN(nn.Module):
- """Implements a single deepspeech encoder layer.
- """
+ """Implements a single deepspeech encoder layer."""
+
config: DeepspeechConfig
@nn.compact
@@ -437,16 +459,17 @@ def __call__(self, inputs, input_paddings, train):
if config.layernorm_everywhere:
inputs = LayerNorm(config.encoder_dim)(inputs)
else:
- inputs = BatchNorm(config.encoder_dim,
- config.dtype,
- config.batch_norm_momentum,
- config.batch_norm_epsilon)(inputs,
- input_paddings,
- train)
+ inputs = BatchNorm(
+ config.encoder_dim,
+ config.dtype,
+ config.batch_norm_momentum,
+ config.batch_norm_epsilon,
+ )(inputs, input_paddings, train)
output = CudnnLSTM(
- features=config.encoder_dim // 2,
- bidirectional=config.bidirectional,
- num_layers=1)(inputs, input_paddings)
+ features=config.encoder_dim // 2,
+ bidirectional=config.bidirectional,
+ num_layers=1,
+ )(inputs, input_paddings)
return output
@@ -458,22 +481,23 @@ class Deepspeech(nn.Module):
for each time step. The output is then fed into a CTC loss which eliminates
the need for alignment with targets.
"""
+
config: DeepspeechConfig
def setup(self):
config = self.config
self.specaug = spectrum_augmenter.SpecAug(
- freq_mask_count=config.freq_mask_count,
- freq_mask_max_bins=config.freq_mask_max_bins,
- time_mask_count=config.time_mask_count,
- time_mask_max_frames=config.time_mask_max_frames,
- time_mask_max_ratio=config.time_mask_max_ratio,
- time_masks_per_frame=config.time_masks_per_frame,
- use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames
+ freq_mask_count=config.freq_mask_count,
+ freq_mask_max_bins=config.freq_mask_max_bins,
+ time_mask_count=config.time_mask_count,
+ time_mask_max_frames=config.time_mask_max_frames,
+ time_mask_max_ratio=config.time_mask_max_ratio,
+ time_masks_per_frame=config.time_masks_per_frame,
+ use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames,
)
@nn.compact
- def __call__(self, inputs, input_paddings, train):
+ def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE):
config = self.config
outputs = inputs
@@ -482,10 +506,10 @@ def __call__(self, inputs, input_paddings, train):
# Compute normalized log mel spectrograms from input audio signal.
preprocessing_config = preprocessor.LibrispeechPreprocessingConfig()
outputs, output_paddings = preprocessor.MelFilterbankFrontend(
- preprocessing_config,
- per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
- per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)(outputs,
- output_paddings)
+ preprocessing_config,
+ per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
+ per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR,
+ )(outputs, output_paddings)
# Ablate random parts of input along temporal and frequency dimension
# following the specaug procedure in https://arxiv.org/abs/1904.08779.
@@ -493,8 +517,9 @@ def __call__(self, inputs, input_paddings, train):
outputs, output_paddings = self.specaug(outputs, output_paddings)
# Subsample input by a factor of 4 by performing strided convolutions.
- outputs, output_paddings = Subsample(
- config=config)(outputs, output_paddings, train)
+ outputs, output_paddings = Subsample(config=config)(
+ outputs, output_paddings, train, dropout_rate=dropout_rate
+ )
# Run the lstm layers.
for _ in range(config.num_lstm_layers):
@@ -506,20 +531,21 @@ def __call__(self, inputs, input_paddings, train):
for _ in range(config.num_ffn_layers):
if config.enable_residual_connections:
outputs = outputs + FeedForwardModule(config=self.config)(
- outputs, output_paddings, train)
+          outputs, output_paddings, train, dropout_rate=dropout_rate
+ )
else:
- outputs = FeedForwardModule(config=self.config)(outputs,
- output_paddings,
- train)
+ outputs = FeedForwardModule(config=self.config)(
+ outputs, output_paddings, train, dropout_rate=dropout_rate
+ )
# Run the decoder which in this case is a trivial projection layer.
if config.enable_decoder_layer_norm:
outputs = LayerNorm(config.encoder_dim)(outputs)
outputs = nn.Dense(
- config.vocab_size,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- outputs)
+ config.vocab_size,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(outputs)
return outputs, output_paddings
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
index d3b616f43..b93934abf 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
@@ -1,40 +1,29 @@
import functools
from typing import Dict, Optional, Tuple
-from flax import jax_utils
import jax
import jax.numpy as jnp
import numpy as np
+from flax import jax_utils
-from algoperf import param_utils
-from algoperf import spec
-from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \
- LibriSpeechConformerWorkload
+from algoperf import param_utils, spec
+from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import (
+ LibriSpeechConformerWorkload,
+)
from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models
class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload):
-
- def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """Deepspeech model init function.
-
- Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate
- as input_dropout_rate.
- """
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
+ """Deepspeech model init function."""
model_config = models.DeepspeechConfig(
- feed_forward_dropout_rate=dropout_rate,
- use_specaug=self.use_specaug,
- input_dropout_rate=aux_dropout_rate,
- use_tanh=self.use_tanh,
- enable_residual_connections=self.enable_residual_connections,
- enable_decoder_layer_norm=self.enable_decoder_layer_norm,
- layernorm_everywhere=self.layernorm_everywhere,
- freq_mask_count=self.freq_mask_count,
- time_mask_count=self.time_mask_count,
+ use_specaug=self.use_specaug,
+ use_tanh=self.use_tanh,
+ enable_residual_connections=self.enable_residual_connections,
+ enable_decoder_layer_norm=self.enable_decoder_layer_norm,
+ layernorm_everywhere=self.layernorm_everywhere,
+ freq_mask_count=self.freq_mask_count,
+ time_mask_count=self.time_mask_count,
)
self._model = models.Deepspeech(model_config)
input_shape = [(320000,), (320000,)]
@@ -42,12 +31,17 @@ def init_model_fn(
model_init_fn = jax.jit(functools.partial(self._model.init, train=False))
- params_rng, dropout_rng = jax.random.split(rng, 2)
- variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng},
- *fake_input_batch)
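+    # No 'dropout' rng is needed here: the model is initialized with
+    # train=False, and dropout rates are supplied at apply() time instead.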
+ params_rng, _ = jax.random.split(rng, 2)
+ variables = model_init_fn(
+ {
+ 'params': params_rng,
+ },
+ *fake_input_batch,
+ )
- model_state = variables[
- 'batch_stats'] if not self.layernorm_everywhere else {}
+ model_state = (
+ variables['batch_stats'] if not self.layernorm_everywhere else {}
+ )
params = variables['params']
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
@@ -56,34 +50,34 @@ def init_model_fn(
return params, model_state
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- use_running_average_bn: Optional[bool] = None
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ use_running_average_bn: Optional[bool] = None,
+    dropout_rate: float = models.DROPOUT_RATE,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
variables = {'params': params, **model_state}
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
is_train_mode = mode == spec.ForwardPassMode.TRAIN
if update_batch_norm or is_train_mode:
(logits, logit_paddings), new_model_state = self._model.apply(
- variables,
- inputs,
- input_paddings,
- train=True,
- rngs={'dropout' : rng},
- mutable=['batch_stats'])
+ variables,
+ inputs,
+ input_paddings,
+ train=True,
+ rngs={'dropout': rng},
+ mutable=['batch_stats'],
+ dropout_rate=dropout_rate,
+ )
return (logits, logit_paddings), new_model_state
else:
logits, logit_paddings = self._model.apply(
- variables,
- inputs,
- input_paddings,
- train=False,
- mutable=False)
+ variables, inputs, input_paddings, train=False, mutable=False
+ )
return (logits, logit_paddings), model_state
def is_output_params(self, param_key: spec.ParameterKey) -> bool:
@@ -132,7 +126,6 @@ def time_mask_count(self) -> int:
class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechDeepSpeechWorkload):
-
@property
def use_tanh(self) -> bool:
return True
@@ -147,7 +140,6 @@ def test_target_value(self) -> float:
class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechDeepSpeechWorkload):
-
@property
def enable_residual_connections(self) -> bool:
return False
@@ -161,9 +153,9 @@ def test_target_value(self) -> float:
return 0.079297
-class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload
- ):
-
+class LibriSpeechDeepSpeechNormAndSpecAugWorkload(
+ LibriSpeechDeepSpeechWorkload
+):
@property
def eval_batch_size(self) -> int:
return 128
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
index 84d317326..aab75da63 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
@@ -2,26 +2,30 @@
https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py.
"""
-from dataclasses import dataclass
import os
-from typing import Optional, Tuple
+from dataclasses import dataclass
+from typing import Tuple
import torch
-from torch import nn
import torch.distributed.nn as dist_nn
import torch.nn.functional as F
+from torch import nn
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \
- preprocessor
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \
- SpecAug
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch import (
+ preprocessor,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import (
+ SpecAug,
+)
USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
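+# Default dropout rate for the Subsample and FeedForwardModule blocks; a
+# different `dropout_rate` can be passed per forward() call.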
+DROPOUT_RATE = 0.1
@dataclass
class DeepspeechConfig:
"""Global hyperparameters used to minimize obnoxious kwarg plumbing."""
+
vocab_size: int = 1024
encoder_dim: int = 512
num_lstm_layers: int = 6
@@ -38,10 +42,6 @@ class DeepspeechConfig:
use_dynamic_time_mask_max_frames: bool = True
batch_norm_momentum: float = 1 - 0.999
batch_norm_epsilon: float = 0.001
- # If None, defaults to 0.1.
- input_dropout_rate: Optional[float] = 0.1
- # If None, defaults to 0.1.
- feed_forward_dropout_rate: Optional[float] = 0.1
enable_residual_connections: bool = True
enable_decoder_layer_norm: bool = True
bidirectional: bool = True
@@ -50,7 +50,6 @@ class DeepspeechConfig:
class LayerNorm(nn.Module):
-
def __init__(self, dim, epsilon=1e-6):
super().__init__()
self.dim = dim
@@ -64,14 +63,13 @@ def forward(self, x):
var = x.var(dim=-1, unbiased=False, keepdims=True)
normed_x = (x - mean) * torch.rsqrt(var + self.epsilon)
- normed_x *= (1 + self.scale)
+ normed_x *= 1 + self.scale
normed_x += self.bias
return normed_x
class Subsample(nn.Module):
-
def __init__(self, config: DeepspeechConfig):
super().__init__()
encoder_dim = config.encoder_dim
@@ -79,21 +77,17 @@ def __init__(self, config: DeepspeechConfig):
self.encoder_dim = encoder_dim
self.conv1 = Conv2dSubsampling(
- input_channels=1, output_channels=encoder_dim, use_tanh=config.use_tanh)
+ input_channels=1, output_channels=encoder_dim, use_tanh=config.use_tanh
+ )
self.conv2 = Conv2dSubsampling(
- input_channels=encoder_dim,
- output_channels=encoder_dim,
- use_tanh=config.use_tanh)
+ input_channels=encoder_dim,
+ output_channels=encoder_dim,
+ use_tanh=config.use_tanh,
+ )
self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True)
- if config.input_dropout_rate is None:
- input_dropout_rate = 0.1
- else:
- input_dropout_rate = config.input_dropout_rate
- self.dropout = nn.Dropout(p=input_dropout_rate)
-
- def forward(self, inputs, input_paddings):
+ def forward(self, inputs, input_paddings, dropout_rate):
output_paddings = input_paddings
outputs = inputs[:, None, :, :]
@@ -101,26 +95,27 @@ def forward(self, inputs, input_paddings):
outputs, output_paddings = self.conv2(outputs, output_paddings)
batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape
- outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size,
- subsampled_lengths,
- subsampled_dims * channels)
+ outputs = outputs.permute(0, 2, 3, 1).reshape(
+ batch_size, subsampled_lengths, subsampled_dims * channels
+ )
outputs = self.lin(outputs)
- outputs = self.dropout(outputs)
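+    # Use functional dropout so the rate can be chosen per forward pass
+    # instead of being fixed in an nn.Dropout module at construction time.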
+ outputs = F.dropout(outputs, dropout_rate, training=self.training)
return outputs, output_paddings
class Conv2dSubsampling(nn.Module):
-
- def __init__(self,
- input_channels: int,
- output_channels: int,
- filter_stride: Tuple[int] = (2, 2),
- padding: str = 'SAME',
- batch_norm_momentum: float = 0.999,
- batch_norm_epsilon: float = 0.001,
- use_tanh: bool = False):
+ def __init__(
+ self,
+ input_channels: int,
+ output_channels: int,
+ filter_stride: Tuple[int] = (2, 2),
+ padding: str = 'SAME',
+ batch_norm_momentum: float = 0.999,
+ batch_norm_epsilon: float = 0.001,
+ use_tanh: bool = False,
+ ):
super().__init__()
self.input_channels = input_channels
@@ -131,7 +126,8 @@ def __init__(self,
self.filter_shape = (output_channels, input_channels, 3, 3)
self.kernel = nn.Parameter(
- nn.init.xavier_uniform_(torch.empty(*self.filter_shape)))
+ nn.init.xavier_uniform_(torch.empty(*self.filter_shape))
+ )
self.bias = nn.Parameter(torch.zeros(output_channels))
self.use_tanh = use_tanh
@@ -162,12 +158,13 @@ def forward(self, inputs, paddings):
else:
in_ = inputs
outputs = F.conv2d(
- input=in_,
- weight=self.kernel,
- bias=self.bias,
- stride=self.filter_stride,
- dilation=(1, 1),
- groups=groups)
+ input=in_,
+ weight=self.kernel,
+ bias=self.bias,
+ stride=self.filter_stride,
+ dilation=(1, 1),
+ groups=groups,
+ )
if self.use_tanh:
outputs = F.tanh(outputs)
@@ -178,21 +175,24 @@ def forward(self, inputs, paddings):
stride = self.filter_stride[0]
pad_len = (input_length + stride - 1) // stride * stride - input_length
out_padding = F.conv1d(
- input=torch.cat([
- paddings[:, None, :],
- torch.zeros(
- size=(paddings.shape[0], 1, pad_len), device=paddings.device)
+ input=torch.cat(
+ [
+ paddings[:, None, :],
+ torch.zeros(
+ size=(paddings.shape[0], 1, pad_len), device=paddings.device
+ ),
],
- dim=2),
- weight=torch.ones([1, 1, 1], device=paddings.device),
- stride=self.filter_stride[:1])
+ dim=2,
+ ),
+ weight=torch.ones([1, 1, 1], device=paddings.device),
+ stride=self.filter_stride[:1],
+ )
out_padding = out_padding.squeeze(dim=1)
outputs = outputs * (1 - out_padding[:, None, :, None])
return outputs, out_padding
class FeedForwardModule(nn.Module):
-
def __init__(self, config: DeepspeechConfig):
super().__init__()
self.config = config
@@ -201,17 +201,13 @@ def __init__(self, config: DeepspeechConfig):
self.normalization_layer = LayerNorm(config.encoder_dim)
else:
self.bn_normalization_layer = BatchNorm(
- dim=config.encoder_dim,
- batch_norm_momentum=config.batch_norm_momentum,
- batch_norm_epsilon=config.batch_norm_epsilon)
+ dim=config.encoder_dim,
+ batch_norm_momentum=config.batch_norm_momentum,
+ batch_norm_epsilon=config.batch_norm_epsilon,
+ )
self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True)
- if config.feed_forward_dropout_rate is None:
- feed_forward_dropout_rate = 0.1
- else:
- feed_forward_dropout_rate = config.feed_forward_dropout_rate
- self.dropout = nn.Dropout(p=feed_forward_dropout_rate)
- def forward(self, inputs, input_paddings):
+ def forward(self, inputs, input_paddings, dropout_rate):
padding_mask = (1 - input_paddings)[:, :, None]
if self.config.layernorm_everywhere:
inputs = self.normalization_layer(inputs)
@@ -226,13 +222,12 @@ def forward(self, inputs, input_paddings):
inputs = F.relu(inputs)
inputs = inputs * padding_mask
- inputs = self.dropout(inputs)
+ inputs = F.dropout(inputs, dropout_rate, training=self.training)
return inputs
class BatchNorm(nn.Module):
-
def __init__(self, dim, batch_norm_momentum, batch_norm_epsilon):
super().__init__()
running_mean = torch.zeros(dim)
@@ -247,8 +242,8 @@ def __init__(self, dim, batch_norm_momentum, batch_norm_epsilon):
self.dim = dim
def forward(self, inputs, input_paddings):
- #inputs: NHD
- #padding: NH
+ # inputs: NHD
+ # padding: NH
mask = 1 - input_paddings[:, :, None]
if self.training:
count = mask.sum()
@@ -265,9 +260,11 @@ def forward(self, inputs, input_paddings):
var = sum_ / count
self.running_mean = (1 - self.momentum) * self.running_mean + (
- self.momentum) * mean.detach()
+ self.momentum
+ ) * mean.detach()
self.running_var = (1 - self.momentum) * self.running_var + (
- self.momentum) * var.detach()
+ self.momentum
+ ) * var.detach()
else:
mean = self.running_mean
var = self.running_var
@@ -278,7 +275,6 @@ def forward(self, inputs, input_paddings):
class BatchRNN(nn.Module):
-
def __init__(self, config: DeepspeechConfig):
super().__init__()
self.config = config
@@ -290,19 +286,23 @@ def __init__(self, config: DeepspeechConfig):
if config.layernorm_everywhere:
self.normalization_layer = LayerNorm(config.encoder_dim)
else:
- self.bn_normalization_layer = BatchNorm(config.encoder_dim,
- config.batch_norm_momentum,
- config.batch_norm_epsilon)
+ self.bn_normalization_layer = BatchNorm(
+ config.encoder_dim,
+ config.batch_norm_momentum,
+ config.batch_norm_epsilon,
+ )
if bidirectional:
self.lstm = nn.LSTM(
- input_size=input_size,
- hidden_size=hidden_size // 2,
- bidirectional=True,
- batch_first=True)
+ input_size=input_size,
+ hidden_size=hidden_size // 2,
+ bidirectional=True,
+ batch_first=True,
+ )
else:
self.lstm = nn.LSTM(
- input_size=input_size, hidden_size=hidden_size, batch_first=True)
+ input_size=input_size, hidden_size=hidden_size, batch_first=True
+ )
def forward(self, inputs, input_paddings):
if self.config.layernorm_everywhere:
@@ -311,50 +311,59 @@ def forward(self, inputs, input_paddings):
inputs = self.bn_normalization_layer(inputs, input_paddings)
lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy()
packed_inputs = torch.nn.utils.rnn.pack_padded_sequence(
- inputs, lengths, batch_first=True, enforce_sorted=False)
+ inputs, lengths, batch_first=True, enforce_sorted=False
+ )
packed_outputs, _ = self.lstm(packed_inputs)
outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(
- packed_outputs, batch_first=True)
+ packed_outputs, batch_first=True
+ )
if outputs.shape[1] < inputs.shape[1]:
- outputs = torch.cat([
+ outputs = torch.cat(
+ [
outputs,
torch.zeros(
- size=(outputs.shape[0],
- inputs.shape[1] - outputs.shape[1],
- outputs.shape[2]),
- device=outputs.device)
- ],
- dim=1)
+ size=(
+ outputs.shape[0],
+ inputs.shape[1] - outputs.shape[1],
+ outputs.shape[2],
+ ),
+ device=outputs.device,
+ ),
+ ],
+ dim=1,
+ )
return outputs
class DeepspeechEncoderDecoder(nn.Module):
-
def __init__(self, config: DeepspeechConfig):
super().__init__()
self.config = config
self.specaug = SpecAug(
- freq_mask_count=config.freq_mask_count,
- freq_mask_max_bins=config.freq_mask_max_bins,
- time_mask_count=config.time_mask_count,
- time_mask_max_frames=config.time_mask_max_frames,
- time_mask_max_ratio=config.time_mask_max_ratio,
- time_masks_per_frame=config.time_masks_per_frame,
- use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames
+ freq_mask_count=config.freq_mask_count,
+ freq_mask_max_bins=config.freq_mask_max_bins,
+ time_mask_count=config.time_mask_count,
+ time_mask_max_frames=config.time_mask_max_frames,
+ time_mask_max_ratio=config.time_mask_max_ratio,
+ time_masks_per_frame=config.time_masks_per_frame,
+ use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames,
)
preprocessing_config = preprocessor.PreprocessorConfig()
self.preprocessor = preprocessor.MelFilterbankFrontend(
- preprocessing_config,
- per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
- per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)
+ preprocessing_config,
+ per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
+ per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR,
+ )
self.subsample = Subsample(config=config)
self.lstms = nn.ModuleList(
- [BatchRNN(config) for _ in range(config.num_lstm_layers)])
+ [BatchRNN(config) for _ in range(config.num_lstm_layers)]
+ )
self.ffns = nn.ModuleList(
- [FeedForwardModule(config) for _ in range(config.num_ffn_layers)])
+ [FeedForwardModule(config) for _ in range(config.num_ffn_layers)]
+ )
if config.enable_decoder_layer_norm:
self.ln = LayerNorm(config.encoder_dim)
@@ -363,14 +372,16 @@ def __init__(self, config: DeepspeechConfig):
self.lin = nn.Linear(config.encoder_dim, config.vocab_size)
- def forward(self, inputs, input_paddings):
+ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE):
outputs = inputs
output_paddings = input_paddings
outputs, output_paddings = self.preprocessor(outputs, output_paddings)
if self.training and self.config.use_specaug:
outputs, output_paddings = self.specaug(outputs, output_paddings)
- outputs, output_paddings = self.subsample(outputs, output_paddings)
+ outputs, output_paddings = self.subsample(
+ outputs, output_paddings, dropout_rate
+ )
for idx in range(self.config.num_lstm_layers):
if self.config.enable_residual_connections:
outputs = outputs + self.lstms[idx](outputs, output_paddings)
@@ -379,9 +390,11 @@ def forward(self, inputs, input_paddings):
for idx in range(self.config.num_ffn_layers):
if self.config.enable_residual_connections:
- outputs = outputs + self.ffns[idx](outputs, output_paddings)
+ outputs = outputs + self.ffns[idx](
+ outputs, output_paddings, dropout_rate
+ )
else:
- outputs = self.ffns[idx](outputs, output_paddings)
+ outputs = self.ffns[idx](outputs, output_paddings, dropout_rate)
if self.config.enable_decoder_layer_norm:
outputs = self.ln(outputs)
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
index e5387f5cb..672f3440f 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
@@ -1,19 +1,21 @@
-from typing import Optional
+from typing import Dict, Tuple
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
-from algoperf import param_utils
-from algoperf import spec
+from algoperf import param_utils, spec
from algoperf.pytorch_utils import pytorch_setup
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
- initialize
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \
- LibriSpeechConformerWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \
- DeepspeechConfig
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \
- DeepspeechEncoderDecoder
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch import models
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import (
+ initialize,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import (
+ LibriSpeechConformerWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import (
+ DeepspeechConfig,
+ DeepspeechEncoderDecoder,
+)
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
@@ -21,29 +23,20 @@
class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload):
-
- def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """Deepspeech model init function.
-
- Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate
- as input_dropout_rate.
- """
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
+ """Deepspeech model init function."""
torch.random.manual_seed(rng[0])
model = DeepspeechEncoderDecoder(
- DeepspeechConfig(
- feed_forward_dropout_rate=dropout_rate,
- use_specaug=self.use_specaug,
- input_dropout_rate=aux_dropout_rate,
- use_tanh=self.use_tanh,
- enable_residual_connections=self.enable_residual_connections,
- enable_decoder_layer_norm=self.enable_decoder_layer_norm,
- layernorm_everywhere=self.layernorm_everywhere,
- freq_mask_count=self.freq_mask_count,
- time_mask_count=self.time_mask_count)).eval()
+ DeepspeechConfig(
+ use_specaug=self.use_specaug,
+ use_tanh=self.use_tanh,
+ enable_residual_connections=self.enable_residual_connections,
+ enable_decoder_layer_norm=self.enable_decoder_layer_norm,
+ layernorm_everywhere=self.layernorm_everywhere,
+ freq_mask_count=self.freq_mask_count,
+ time_mask_count=self.time_mask_count,
+ )
+ ).eval()
self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none')
# Run model once to initialize lazy layers.
t = MAX_INPUT_LENGTH
@@ -63,6 +56,28 @@ def init_model_fn(
model = torch.nn.DataParallel(model)
return model, None
+ def model_fn(
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+    # Override the parent method, changing only the default dropout_rate.
+ # pylint: disable=useless-parent-delegation
+ return super().model_fn(
+ params,
+ augmented_and_preprocessed_input_batch,
+ model_state,
+ mode,
+ rng,
+ update_batch_norm,
+ dropout_rate,
+ )
+
def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['lin.weight', 'lin.bias']
@@ -109,7 +124,6 @@ def time_mask_count(self) -> int:
class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechDeepSpeechWorkload):
-
@property
def use_tanh(self) -> bool:
return True
@@ -124,7 +138,6 @@ def test_target_value(self) -> float:
class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechDeepSpeechWorkload):
-
@property
def enable_residual_connections(self) -> bool:
return False
@@ -138,9 +151,9 @@ def test_target_value(self) -> float:
return 0.079297
-class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload
- ):
-
+class LibriSpeechDeepSpeechNormAndSpecAugWorkload(
+ LibriSpeechDeepSpeechWorkload
+):
@property
def eval_batch_size(self) -> int:
return 128
diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py
index 5a4382da1..0d192cbf5 100644
--- a/algoperf/workloads/mnist/mnist_jax/workload.py
+++ b/algoperf/workloads/mnist/mnist_jax/workload.py
@@ -3,20 +3,18 @@
import functools
from typing import Any, Dict, Optional, Tuple
-from flax import jax_utils
-from flax import linen as nn
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from flax import linen as nn
+from jax import lax
-from algoperf import param_utils
-from algoperf import spec
+from algoperf import param_utils, spec
from algoperf.workloads.mnist.workload import BaseMnistWorkload
class _Model(nn.Module):
-
@nn.compact
def __call__(self, x: spec.Tensor, train: bool) -> spec.Tensor:
del train
@@ -31,19 +29,12 @@ def __call__(self, x: spec.Tensor, train: bool) -> spec.Tensor:
class MnistWorkload(BaseMnistWorkload):
-
- def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """Dropout is unused."""
- del dropout_rate
- del aux_dropout_rate
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
self._model = _Model()
- initial_params = self._model.init({'params': rng}, init_val,
- train=True)['params']
+ initial_params = self._model.init({'params': rng}, init_val, train=True)[
+ 'params'
+ ]
self._param_shapes = param_utils.jax_param_shapes(initial_params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
return jax_utils.replicate(initial_params), None
@@ -52,31 +43,34 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_1'
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
train = mode == spec.ForwardPassMode.TRAIN
logits_batch = self._model.apply(
- {'params': params},
- augmented_and_preprocessed_input_batch['inputs'],
- train=train)
+ {'params': params},
+ augmented_and_preprocessed_input_batch['inputs'],
+ train=train,
+ )
return logits_batch, None
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -86,7 +80,8 @@ def loss_fn(
one_hot_targets = jax.nn.one_hot(label_batch, 10)
smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing)
per_example_losses = -jnp.sum(
- smoothed_targets * nn.log_softmax(logits_batch), axis=-1)
+ smoothed_targets * nn.log_softmax(logits_batch), axis=-1
+ )
# `mask_batch` is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -95,41 +90,45 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': n_valid_examples,
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': n_valid_examples,
+ 'per_example': per_example_losses,
}
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, 0, 0, 0, None),
- static_broadcasted_argnums=(0,))
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, 0, 0, 0, None),
+ static_broadcasted_argnums=(0,),
+ )
def _eval_model(
- self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
logits, _ = self.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.EVAL,
- rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.EVAL,
+ rng,
+ update_batch_norm=False,
+ )
weights = batch.get('weights')
if weights is None:
weights = jnp.ones(len(logits))
accuracy = jnp.sum(
- (jnp.argmax(logits, axis=-1) == batch['targets']) * weights)
+ (jnp.argmax(logits, axis=-1) == batch['targets']) * weights
+ )
summed_loss = self.loss_fn(batch['targets'], logits, weights)['summed']
metrics = {'accuracy': accuracy, 'loss': summed_loss}
metrics = lax.psum(metrics, axis_name='batch')
return metrics
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics)
diff --git a/algoperf/workloads/mnist/mnist_pytorch/workload.py b/algoperf/workloads/mnist/mnist_pytorch/workload.py
index 780e1bca0..b58898703 100644
--- a/algoperf/workloads/mnist/mnist_pytorch/workload.py
+++ b/algoperf/workloads/mnist/mnist_pytorch/workload.py
@@ -1,18 +1,16 @@
"""MNIST workload implemented in PyTorch."""
-from collections import OrderedDict
import contextlib
+from collections import OrderedDict
from typing import Any, Dict, Iterator, Optional, Tuple
import torch
-from torch import nn
import torch.distributed as dist
import torch.nn.functional as F
+from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
-from algoperf import init_utils
-from algoperf import param_utils
-from algoperf import spec
+from algoperf import init_utils, param_utils, spec
from algoperf.pytorch_utils import pytorch_setup
from algoperf.workloads.mnist.workload import BaseMnistWorkload
@@ -20,18 +18,20 @@
class _Model(nn.Module):
-
def __init__(self) -> None:
super().__init__()
input_size = 28 * 28
num_hidden = 128
num_classes = 10
self.net = nn.Sequential(
- OrderedDict([('layer1',
- torch.nn.Linear(input_size, num_hidden, bias=True)),
- ('layer1_sig', torch.nn.Sigmoid()),
- ('layer2',
- torch.nn.Linear(num_hidden, num_classes, bias=True))]))
+ OrderedDict(
+ [
+ ('layer1', torch.nn.Linear(input_size, num_hidden, bias=True)),
+ ('layer1_sig', torch.nn.Sigmoid()),
+ ('layer2', torch.nn.Linear(num_hidden, num_classes, bias=True)),
+ ]
+ )
+ )
def reset_parameters(self) -> None:
for m in self.net.modules():
@@ -44,16 +44,16 @@ def forward(self, x: spec.Tensor) -> spec.Tensor:
class MnistWorkload(BaseMnistWorkload):
-
def _build_input_queue(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
del cache
if N_GPUS != 0:
per_device_batch_size = int(global_batch_size / N_GPUS)
@@ -63,22 +63,27 @@ def _build_input_queue(
# Only create and iterate over tf input pipeline in one Python process to
# avoid creating too many threads.
if RANK == 0:
- np_iter = super()._build_input_queue(data_rng,
- split,
- data_dir,
- global_batch_size,
- num_batches,
- repeat_final_dataset)
+ np_iter = super()._build_input_queue(
+ data_rng,
+ split,
+ data_dir,
+ global_batch_size,
+ num_batches,
+ repeat_final_dataset,
+ )
while True:
if RANK == 0:
batch = next(np_iter) # pylint: disable=stop-iteration-return
inputs = torch.as_tensor(
- batch['inputs'], dtype=torch.float32, device=DEVICE)
+ batch['inputs'], dtype=torch.float32, device=DEVICE
+ )
targets = torch.as_tensor(
- batch['targets'], dtype=torch.long, device=DEVICE)
+ batch['targets'], dtype=torch.long, device=DEVICE
+ )
if 'weights' in batch:
weights = torch.as_tensor(
- batch['weights'], dtype=torch.bool, device=DEVICE)
+ batch['weights'], dtype=torch.bool, device=DEVICE
+ )
else:
weights = torch.ones_like(targets, dtype=torch.bool, device=DEVICE)
# Send batch to other devices when using DDP.
@@ -94,34 +99,37 @@ def _build_input_queue(
targets = targets.view(-1, *targets.shape[2:])
weights = weights.view(-1, *weights.shape[2:])
else:
- inputs = torch.empty((N_GPUS, per_device_batch_size, 28, 28, 1),
- dtype=torch.float32,
- device=DEVICE)
+ inputs = torch.empty(
+ (N_GPUS, per_device_batch_size, 28, 28, 1),
+ dtype=torch.float32,
+ device=DEVICE,
+ )
dist.broadcast(inputs, src=0)
inputs = inputs[RANK]
- targets = torch.empty((N_GPUS, per_device_batch_size),
- dtype=torch.long,
- device=DEVICE)
+ targets = torch.empty(
+ (N_GPUS, per_device_batch_size), dtype=torch.long, device=DEVICE
+ )
dist.broadcast(targets, src=0)
targets = targets[RANK]
- weights = torch.empty((N_GPUS, per_device_batch_size),
- dtype=torch.bool,
- device=DEVICE)
+ weights = torch.empty(
+ (N_GPUS, per_device_batch_size), dtype=torch.bool, device=DEVICE
+ )
dist.broadcast(weights, src=0)
weights = weights[RANK]
batch = {
- 'inputs': inputs.permute(0, 3, 1, 2),
- 'targets': targets,
- 'weights': weights,
+ 'inputs': inputs.permute(0, 3, 1, 2),
+ 'targets': targets,
+ 'weights': weights,
}
yield batch
def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ self,
+ rng: spec.RandomState,
+ dropout_rate: Optional[float] = None,
+ aux_dropout_rate: Optional[float] = None,
+ ) -> spec.ModelInitState:
"""Dropout is unused."""
del dropout_rate
del aux_dropout_rate
@@ -149,13 +157,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
-    return param_key in ['net.layer2.weight', 'net_layer2.bias']
+    return param_key in ['net.layer2.weight', 'net.layer2.bias']
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
@@ -163,8 +172,8 @@ def model_fn(
if mode == spec.ForwardPassMode.EVAL:
model.eval()
contexts = {
- spec.ForwardPassMode.EVAL: torch.no_grad,
- spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+ spec.ForwardPassMode.EVAL: torch.no_grad,
+ spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
logits_batch = model(augmented_and_preprocessed_input_batch['inputs'])
@@ -173,11 +182,12 @@ def model_fn(
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -185,10 +195,11 @@ def loss_fn(
(not synced across devices).
"""
per_example_losses = F.cross_entropy(
- logits_batch,
- label_batch,
- reduction='none',
- label_smoothing=label_smoothing)
+ logits_batch,
+ label_batch,
+ reduction='none',
+ label_smoothing=label_smoothing,
+ )
# `mask_batch` is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -197,25 +208,27 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
+ 'per_example': per_example_losses,
}
def _eval_model(
- self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
"""Return the mean accuracy and loss as a dict."""
logits, _ = self.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.EVAL,
- rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.EVAL,
+ rng,
+ update_batch_norm=False,
+ )
targets = batch['targets']
weights = batch.get('weights')
if weights is None:
@@ -227,8 +240,8 @@ def _eval_model(
return {'accuracy': accuracy, 'loss': summed_loss}
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
dist.all_reduce(metric)
diff --git a/algoperf/workloads/mnist/workload.py b/algoperf/workloads/mnist/workload.py
index f53aadd0b..38006b9ac 100644
--- a/algoperf/workloads/mnist/workload.py
+++ b/algoperf/workloads/mnist/workload.py
@@ -10,10 +10,9 @@
import tensorflow_datasets as tfds
import torch
-from algoperf import data_utils
-from algoperf import spec
-from algoperf.pytorch_utils import pytorch_setup
import algoperf.random_utils as prng
+from algoperf import data_utils, spec
+from algoperf.pytorch_utils import pytorch_setup
USE_PYTORCH_DDP, _, _, _ = pytorch_setup()
@@ -23,16 +22,17 @@ def _normalize(image: spec.Tensor, mean: float, stddev: float) -> spec.Tensor:
def _build_mnist_dataset(
- data_rng: jax.random.PRNGKey,
- num_train_examples: int,
- num_validation_examples: int,
- train_mean: float,
- train_stddev: float,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: bool = False,
- repeat_final_dataset: bool = True) -> Iterator[Dict[str, spec.Tensor]]:
+ data_rng: jax.random.PRNGKey,
+ num_train_examples: int,
+ num_validation_examples: int,
+ train_mean: float,
+ train_stddev: float,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: bool = False,
+ repeat_final_dataset: bool = True,
+) -> Iterator[Dict[str, spec.Tensor]]:
shuffle = split in ['train', 'eval_train']
assert num_train_examples + num_validation_examples == 60000
if shuffle:
@@ -42,12 +42,14 @@ def _build_mnist_dataset(
else:
tfds_split = 'test'
ds = tfds.load(
- 'mnist', split=tfds_split, shuffle_files=False, data_dir=data_dir)
+ 'mnist', split=tfds_split, shuffle_files=False, data_dir=data_dir
+ )
ds = ds.map(
- lambda x: {
- 'inputs': _normalize(x['image'], train_mean, train_stddev),
- 'targets': x['label'],
- })
+ lambda x: {
+ 'inputs': _normalize(x['image'], train_mean, train_stddev),
+ 'targets': x['label'],
+ }
+ )
is_train = split == 'train'
if cache:
@@ -62,22 +64,23 @@ def _build_mnist_dataset(
ds = ds.repeat()
ds = map(
- functools.partial(
- data_utils.shard_and_maybe_pad_np,
- global_batch_size=global_batch_size),
- ds)
+ functools.partial(
+ data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size
+ ),
+ ds,
+ )
return iter(ds)
class BaseMnistWorkload(spec.Workload):
-
@property
def target_metric_name(self) -> str:
"""The name of the target metric (useful for scoring/processing code)."""
return 'accuracy'
- def has_reached_validation_target(self, eval_result: Dict[str,
- float]) -> bool:
+ def has_reached_validation_target(
+ self, eval_result: Dict[str, float]
+ ) -> bool:
return eval_result['validation/accuracy'] > self.validation_target_value
@property
@@ -104,8 +107,9 @@ def num_eval_train_examples(self) -> int:
# Round up from num_validation_examples (which is the default for
# num_eval_train_examples) to the next multiple of eval_batch_size, so that
# we don't have to extract the correctly sized subset of the training data.
- rounded_up_multiple = math.ceil(self.num_validation_examples /
- self.eval_batch_size)
+ rounded_up_multiple = math.ceil(
+ self.num_validation_examples / self.eval_batch_size
+ )
return rounded_up_multiple * self.eval_batch_size
@property
@@ -138,31 +142,33 @@ def eval_period_time_sec(self) -> int:
@abc.abstractmethod
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
def _build_input_queue(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
del num_batches
ds = _build_mnist_dataset(
- data_rng=data_rng,
- num_train_examples=self.num_train_examples,
- num_validation_examples=self.num_validation_examples,
- train_mean=self.train_mean,
- train_stddev=self.train_stddev,
- split=split,
- data_dir=data_dir,
- global_batch_size=global_batch_size,
- cache=cache,
- repeat_final_dataset=repeat_final_dataset)
+ data_rng=data_rng,
+ num_train_examples=self.num_train_examples,
+ num_validation_examples=self.num_validation_examples,
+ train_mean=self.train_mean,
+ train_stddev=self.train_stddev,
+ split=split,
+ data_dir=data_dir,
+ global_batch_size=global_batch_size,
+ cache=cache,
+ repeat_final_dataset=repeat_final_dataset,
+ )
return ds
@property
@@ -173,49 +179,52 @@ def step_hint(self) -> int:
return 7813
def _eval_model(
- self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
raise NotImplementedError
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del global_step
data_rng, model_rng = prng.split(rng, 2)
if split not in self._eval_iters:
self._eval_iters[split] = self._build_input_queue(
- data_rng=data_rng,
- split=split,
- data_dir=data_dir,
- global_batch_size=global_batch_size,
- cache=True,
- repeat_final_dataset=True)
+ data_rng=data_rng,
+ split=split,
+ data_dir=data_dir,
+ global_batch_size=global_batch_size,
+ cache=True,
+ repeat_final_dataset=True,
+ )
total_metrics = {
- 'accuracy': 0.,
- 'loss': 0.,
+ 'accuracy': 0.0,
+ 'loss': 0.0,
}
num_batches = int(math.ceil(num_examples / global_batch_size))
num_devices = max(torch.cuda.device_count(), jax.local_device_count())
for _ in range(num_batches):
batch = next(self._eval_iters[split])
per_device_model_rngs = prng.split(model_rng, num_devices)
- batch_metrics = self._eval_model(params,
- batch,
- model_state,
- per_device_model_rngs)
+ batch_metrics = self._eval_model(
+ params, batch, model_state, per_device_model_rngs
+ )
total_metrics = {
- k: v + batch_metrics[k] for k, v in total_metrics.items()
+ k: v + batch_metrics[k] for k, v in total_metrics.items()
}
return self._normalize_eval_metrics(num_examples, total_metrics)
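
The reformatted `_eval_model_on_split` keeps its original accumulation scheme: per-batch metric sums are added into `total_metrics` and then normalized by the number of evaluated examples. A minimal sketch of that pattern with made-up batch sums (not the real workload values):

```python
# Hypothetical per-batch sums, standing in for what `_eval_model` returns.
batch_sums = [
  {'accuracy': 54.0, 'loss': 12.3},
  {'accuracy': 58.0, 'loss': 11.7},
  {'accuracy': 30.0, 'loss': 6.0},  # final, partially filled batch
]
num_examples = 150

total_metrics = {'accuracy': 0.0, 'loss': 0.0}
for batch_metrics in batch_sums:
  total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}

# Mirrors `_normalize_eval_metrics`: divide the running sums by the number
# of evaluated examples.
normalized = {k: v / num_examples for k, v in total_metrics.items()}
print(normalized)  # {'accuracy': 0.9466..., 'loss': 0.2}
```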
diff --git a/algoperf/workloads/ogbg/input_pipeline.py b/algoperf/workloads/ogbg/input_pipeline.py
index 3cb6f51de..79a4ddc4a 100644
--- a/algoperf/workloads/ogbg/input_pipeline.py
+++ b/algoperf/workloads/ogbg/input_pipeline.py
@@ -14,10 +14,10 @@
AVG_EDGES_PER_GRAPH = 56
TFDS_SPLIT_NAME = {
- 'train': 'train',
- 'eval_train': 'train',
- 'validation': 'validation',
- 'test': 'test',
+ 'train': 'train',
+ 'eval_train': 'train',
+ 'validation': 'validation',
+ 'test': 'test',
}
@@ -33,11 +33,12 @@ def _load_dataset(split, should_shuffle, data_rng, data_dir):
read_config = tfds.ReadConfig(add_tfds_id=True, shuffle_seed=file_data_rng)
dataset = tfds.load(
- 'ogbg_molpcba:0.1.3',
- split=TFDS_SPLIT_NAME[split],
- shuffle_files=should_shuffle,
- read_config=read_config,
- data_dir=data_dir)
+ 'ogbg_molpcba:0.1.3',
+ split=TFDS_SPLIT_NAME[split],
+ shuffle_files=should_shuffle,
+ read_config=read_config,
+ data_dir=data_dir,
+ )
if should_shuffle:
dataset = dataset.shuffle(seed=dataset_data_rng, buffer_size=2**15)
@@ -62,16 +63,17 @@ def _to_jraph(example):
receivers = edge_index[:, 1]
return jraph.GraphsTuple(
- n_node=num_nodes,
- n_edge=np.array([len(edge_index) * 2]),
- nodes=node_feat,
- edges=np.concatenate([edge_feat, edge_feat]),
- # Make the edges bidirectional
- senders=np.concatenate([senders, receivers]),
- receivers=np.concatenate([receivers, senders]),
- # Keep the labels with the graph for batching. They will be removed
- # in the processed batch.
- globals=np.expand_dims(labels, axis=0))
+ n_node=num_nodes,
+ n_edge=np.array([len(edge_index) * 2]),
+ nodes=node_feat,
+ edges=np.concatenate([edge_feat, edge_feat]),
+ # Make the edges bidirectional
+ senders=np.concatenate([senders, receivers]),
+ receivers=np.concatenate([receivers, senders]),
+ # Keep the labels with the graph for batching. They will be removed
+ # in the processed batch.
+ globals=np.expand_dims(labels, axis=0),
+ )
def _get_weights_by_nan_and_padding(labels, padding_mask):
@@ -123,10 +125,9 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None):
max_n_graphs = per_device_batch_size
jraph_iter = map(_to_jraph, dataset_iter)
- batched_iter = jraph.dynamically_batch(jraph_iter,
- max_n_nodes + 1,
- max_n_edges,
- max_n_graphs + 1)
+ batched_iter = jraph.dynamically_batch(
+ jraph_iter, max_n_nodes + 1, max_n_edges, max_n_graphs + 1
+ )
count = 0
graphs_shards = []
@@ -141,7 +142,8 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None):
graph = batched_graph._replace(globals={})
replaced_labels, weights = _get_weights_by_nan_and_padding(
- labels, jraph.get_graph_padding_mask(graph))
+ labels, jraph.get_graph_padding_mask(graph)
+ )
graphs_shards.append(graph)
labels_shards.append(replaced_labels)
@@ -156,9 +158,9 @@ def f(x):
labels_shards = f(labels_shards)
weights_shards = f(weights_shards)
yield {
- 'inputs': graphs_shards,
- 'targets': labels_shards,
- 'weights': weights_shards,
+ 'inputs': graphs_shards,
+ 'targets': labels_shards,
+ 'weights': weights_shards,
}
count = 0
@@ -170,5 +172,6 @@ def f(x):
def get_dataset_iter(split, data_rng, data_dir, global_batch_size):
shuffle = split in ['train', 'eval_train']
ds = _load_dataset(
- split, should_shuffle=shuffle, data_rng=data_rng, data_dir=data_dir)
+ split, should_shuffle=shuffle, data_rng=data_rng, data_dir=data_dir
+ )
return _get_batch_iterator(iter(ds), global_batch_size)
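
The `+ 1` budgets handed to `jraph.dynamically_batch` reserve room for the padding graph (and its node) that jraph appends to every batch. A self-contained sketch of that call with toy graphs and arbitrary feature sizes, assuming only the public jraph API:

```python
import numpy as np
import jraph


def toy_graph(num_nodes: int, num_edges: int) -> jraph.GraphsTuple:
  # Toy graph; the 9/3/128 feature sizes echo the ogbg shapes but are arbitrary here.
  return jraph.GraphsTuple(
    n_node=np.array([num_nodes]),
    n_edge=np.array([num_edges]),
    nodes=np.ones((num_nodes, 9)),
    edges=np.ones((num_edges, 3)),
    senders=np.zeros(num_edges, dtype=np.int32),
    receivers=np.arange(num_edges, dtype=np.int32) % num_nodes,
    globals=np.zeros((1, 128)),
  )


graphs = iter([toy_graph(3, 4), toy_graph(2, 2), toy_graph(4, 5)])
# As in the pipeline, one extra node and one extra graph are budgeted for padding.
batched = jraph.dynamically_batch(graphs, n_node=7 + 1, n_edge=16, n_graph=2 + 1)
for batch in batched:
  print(batch.n_node)  # per-graph node counts, padding graph included
```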
diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py
index 55f83d905..7d41204c8 100644
--- a/algoperf/workloads/ogbg/metrics.py
+++ b/algoperf/workloads/ogbg/metrics.py
@@ -2,24 +2,23 @@
# https://github.com/google/flax/blob/main/examples/ogbg_molpcba/train.py
from typing import Any
-from clu import metrics
import flax
import jax
import jax.numpy as jnp
import numpy as np
-from sklearn.metrics import average_precision_score
import torch
import torch.distributed as dist
+from clu import metrics
+from sklearn.metrics import average_precision_score
from algoperf.pytorch_utils import pytorch_setup
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
-def predictions_match_labels(*,
- logits: jnp.ndarray,
- labels: jnp.ndarray,
- **kwargs) -> jnp.ndarray:
+def predictions_match_labels(
+ *, logits: jnp.ndarray, labels: jnp.ndarray, **kwargs
+) -> jnp.ndarray:
"""Returns a binary array indicating where predictions match the labels."""
del kwargs # Unused.
preds = logits > 0
@@ -28,7 +27,8 @@ def predictions_match_labels(*,
@flax.struct.dataclass
class MeanAveragePrecision(
- metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask'))):
+ metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask'))
+):
"""Computes the mean average precision (mAP) over different tasks."""
def compute(self):
@@ -37,6 +37,7 @@ def compute(self):
labels = values['labels']
logits = values['logits']
mask = values['mask']
+ sigmoid = jax.nn.sigmoid
if USE_PYTORCH_DDP:
# Sync labels, logits, and masks across devices.
@@ -49,9 +50,14 @@ def compute(self):
all_values[idx] = torch.cat(all_tensors).cpu().numpy()
labels, logits, mask = all_values
+ def sigmoid_np(x):
+ return 1 / (1 + np.exp(-x))
+
+ sigmoid = sigmoid_np
+
mask = mask.astype(bool)
- probs = jax.nn.sigmoid(logits)
+ probs = sigmoid(logits)
num_tasks = labels.shape[1]
average_precisions = np.full(num_tasks, np.nan)
@@ -62,7 +68,8 @@ def compute(self):
if np.sum(labels[:, task] == 0) > 0 and np.sum(labels[:, task] == 1) > 0:
is_labeled = mask[:, task]
average_precisions[task] = average_precision_score(
- labels[is_labeled, task], probs[is_labeled, task])
+ labels[is_labeled, task], probs[is_labeled, task]
+ )
# When all APs are NaNs, return NaN. This avoids raising a RuntimeWarning.
if np.isnan(average_precisions).all():
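
Under DDP the gathered `labels`, `logits`, and `mask` are NumPy arrays, which is why the diff swaps `jax.nn.sigmoid` for a plain NumPy sigmoid in that branch. A NumPy-only sketch of the same masked per-task average precision on random toy data:

```python
import numpy as np
from sklearn.metrics import average_precision_score

rng = np.random.default_rng(0)
num_examples, num_tasks = 32, 4
logits = rng.normal(size=(num_examples, num_tasks))
labels = rng.integers(0, 2, size=(num_examples, num_tasks)).astype(float)
mask = rng.random(size=(num_examples, num_tasks)) > 0.1  # True = labeled

probs = 1 / (1 + np.exp(-logits))  # NumPy stand-in for jax.nn.sigmoid

average_precisions = np.full(num_tasks, np.nan)
for task in range(num_tasks):
  # Only score tasks that contain both positive and negative labels.
  if (labels[:, task] == 0).any() and (labels[:, task] == 1).any():
    is_labeled = mask[:, task]
    average_precisions[task] = average_precision_score(
      labels[is_labeled, task], probs[is_labeled, task]
    )

if np.isnan(average_precisions).all():
  mean_ap = np.nan  # avoids a RuntimeWarning from nanmean over all-NaN
else:
  mean_ap = np.nanmean(average_precisions)
print(mean_ap)
```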
diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py
index 0e66d2ab8..db1ca416c 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/models.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/models.py
@@ -1,21 +1,24 @@
# Forked from the init2winit implementation here
# https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py.
-from typing import Optional, Tuple
+from typing import Tuple
-from flax import linen as nn
import jax.numpy as jnp
import jraph
+from flax import linen as nn
+from algoperf.jax_utils import Dropout
+
+DROPOUT_RATE = 0.1
-def _make_embed(latent_dim, name):
+def _make_embed(latent_dim, name):
def make_fn(inputs):
return nn.Dense(features=latent_dim, name=name)(inputs)
return make_fn
-def _make_mlp(hidden_dims, dropout, activation_fn):
+def _make_mlp(hidden_dims, activation_fn, train, dropout_rate=DROPOUT_RATE):
"""Creates a MLP with specified dimensions."""
@jraph.concatenated_args
@@ -25,7 +28,9 @@ def make_fn(inputs):
x = nn.Dense(features=dim)(x)
x = nn.LayerNorm()(x)
x = activation_fn(x)
- x = dropout(x)
+ x = Dropout(rate=dropout_rate, deterministic=not train)(
+ x, rate=dropout_rate
+ )
return x
return make_fn
@@ -36,28 +41,23 @@ class GNN(nn.Module):
The model assumes the input data is a jraph.GraphsTuple without global
variables. The final prediction will be encoded in the globals.
"""
+
num_outputs: int
latent_dim: int = 256
hidden_dims: Tuple[int] = (256,)
- # If None, defaults to 0.1.
- dropout_rate: Optional[float] = 0.1
num_message_passing_steps: int = 5
activation_fn_name: str = 'relu'
@nn.compact
- def __call__(self, graph, train):
- if self.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = self.dropout_rate
- dropout = nn.Dropout(rate=dropout_rate, deterministic=not train)
-
+ def __call__(self, graph, train, dropout_rate=DROPOUT_RATE):
graph = graph._replace(
- globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]))
+ globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])
+ )
embedder = jraph.GraphMapFeatures(
- embed_node_fn=_make_embed(self.latent_dim, name='node_embedding'),
- embed_edge_fn=_make_embed(self.latent_dim, name='edge_embedding'))
+ embed_node_fn=_make_embed(self.latent_dim, name='node_embedding'),
+ embed_edge_fn=_make_embed(self.latent_dim, name='edge_embedding'),
+ )
graph = embedder(graph)
if self.activation_fn_name == 'relu':
@@ -68,16 +68,30 @@ def __call__(self, graph, train):
activation_fn = nn.silu
else:
raise ValueError(
- f'Invalid activation function name: {self.activation_fn_name}')
+ f'Invalid activation function name: {self.activation_fn_name}'
+ )
for _ in range(self.num_message_passing_steps):
net = jraph.GraphNetwork(
- update_edge_fn=_make_mlp(
- self.hidden_dims, dropout=dropout, activation_fn=activation_fn),
- update_node_fn=_make_mlp(
- self.hidden_dims, dropout=dropout, activation_fn=activation_fn),
- update_global_fn=_make_mlp(
- self.hidden_dims, dropout=dropout, activation_fn=activation_fn))
+ update_edge_fn=_make_mlp(
+ self.hidden_dims,
+ activation_fn=activation_fn,
+ train=train,
+ dropout_rate=dropout_rate,
+ ),
+ update_node_fn=_make_mlp(
+ self.hidden_dims,
+ activation_fn=activation_fn,
+ train=train,
+ dropout_rate=dropout_rate,
+ ),
+ update_global_fn=_make_mlp(
+ self.hidden_dims,
+ activation_fn=activation_fn,
+ train=train,
+ dropout_rate=dropout_rate,
+ ),
+ )
graph = net(graph)
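
The Jax GNN now takes `dropout_rate` as a `__call__` argument with a module-level `DROPOUT_RATE = 0.1` default instead of fixing the rate at construction; `algoperf.jax_utils.Dropout` is the project's own wrapper. A minimal sketch of the same call-time-rate pattern using stock `flax.linen.Dropout` on a toy module (not the benchmark model):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

DROPOUT_RATE = 0.1  # module-level default, mirroring the diff


class TinyMLP(nn.Module):
  """Toy module whose dropout rate is resolved at call time."""

  features: int = 16

  @nn.compact
  def __call__(self, x, train: bool, dropout_rate: float = DROPOUT_RATE):
    x = nn.Dense(self.features)(x)
    x = nn.relu(x)
    # Stock flax dropout; the custom `Dropout` in algoperf.jax_utils is
    # assumed to behave similarly while also accepting the rate per call.
    x = nn.Dropout(rate=dropout_rate, deterministic=not train)(x)
    return x


model = TinyMLP()
x = jnp.ones((2, 8))
params = model.init(jax.random.PRNGKey(0), x, train=False)
out = model.apply(
  params,
  x,
  train=True,
  dropout_rate=0.2,  # overrides the default, as the new model_fn allows
  rngs={'dropout': jax.random.PRNGKey(1)},
)
```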
diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py
index e895d15a7..8471fcdcc 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/workload.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py
@@ -1,47 +1,41 @@
"""OGBG workload implemented in Jax."""
+
import functools
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Tuple
-from flax import jax_utils
import jax
import jax.numpy as jnp
import jraph
import optax
+from flax import jax_utils
-from algoperf import param_utils
-from algoperf import spec
+from algoperf import param_utils, spec
from algoperf.workloads.ogbg import metrics
from algoperf.workloads.ogbg.ogbg_jax import models
from algoperf.workloads.ogbg.workload import BaseOgbgWorkload
class OgbgWorkload(BaseOgbgWorkload):
-
- def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """aux_dropout_rate is unused."""
- del aux_dropout_rate
- rng, params_rng, dropout_rng = jax.random.split(rng, 3)
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
+ rng, params_rng = jax.random.split(rng, 2)
self._model = models.GNN(
- self._num_outputs,
- dropout_rate=dropout_rate,
- activation_fn_name=self.activation_fn_name,
- hidden_dims=self.hidden_dims,
- latent_dim=self.latent_dim,
- num_message_passing_steps=self.num_message_passing_steps)
+ self._num_outputs,
+ activation_fn_name=self.activation_fn_name,
+ hidden_dims=self.hidden_dims,
+ latent_dim=self.latent_dim,
+ num_message_passing_steps=self.num_message_passing_steps,
+ )
init_fn = jax.jit(functools.partial(self._model.init, train=False))
fake_batch = jraph.GraphsTuple(
- n_node=jnp.asarray([1]),
- n_edge=jnp.asarray([1]),
- nodes=jnp.ones((1, 9)),
- edges=jnp.ones((1, 3)),
- globals=jnp.zeros((1, self._num_outputs)),
- senders=jnp.asarray([0]),
- receivers=jnp.asarray([0]))
- params = init_fn({'params': params_rng, 'dropout': dropout_rng}, fake_batch)
+ n_node=jnp.asarray([1]),
+ n_edge=jnp.asarray([1]),
+ nodes=jnp.ones((1, 9)),
+ edges=jnp.ones((1, 3)),
+ globals=jnp.zeros((1, self._num_outputs)),
+ senders=jnp.asarray([0]),
+ receivers=jnp.asarray([0]),
+ )
+ params = init_fn({'params': params_rng}, fake_batch)
params = params['params']
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
@@ -51,37 +45,45 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_17'
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
"""Get predicted logits from the network for input graphs."""
del update_batch_norm # No BN in the GNN model.
if model_state is not None:
raise ValueError(
- f'Expected model_state to be None, received {model_state}.')
+ f'Expected model_state to be None, received {model_state}.'
+ )
train = mode == spec.ForwardPassMode.TRAIN
- logits = self._model.apply({'params': params},
- augmented_and_preprocessed_input_batch['inputs'],
- rngs={'dropout': rng},
- train=train)
+ logits = self._model.apply(
+ {'params': params},
+ augmented_and_preprocessed_input_batch['inputs'],
+ rngs={'dropout': rng},
+ train=train,
+ dropout_rate=dropout_rate,
+ )
return logits, None
def _binary_cross_entropy_with_mask(
- self,
- labels: jnp.ndarray,
- logits: jnp.ndarray,
- mask: jnp.ndarray,
- label_smoothing: float = 0.0) -> jnp.ndarray:
+ self,
+ labels: jnp.ndarray,
+ logits: jnp.ndarray,
+ mask: jnp.ndarray,
+ label_smoothing: float = 0.0,
+ ) -> jnp.ndarray:
"""Binary cross entropy loss for logits, with masked elements."""
if not (logits.shape == labels.shape == mask.shape): # pylint: disable=superfluous-parens
raise ValueError(
- f'Shape mismatch between logits ({logits.shape}), targets '
- f'({labels.shape}), and weights ({mask.shape}).')
+ f'Shape mismatch between logits ({logits.shape}), targets '
+ f'({labels.shape}), and weights ({mask.shape}).'
+ )
if len(logits.shape) != 2:
raise ValueError(f'Rank of logits ({logits.shape}) must be 2.')
@@ -97,26 +99,31 @@ def _binary_cross_entropy_with_mask(
positive_logits = logits >= 0
relu_logits = jnp.where(positive_logits, logits, 0)
abs_logits = jnp.where(positive_logits, logits, -logits)
- losses = relu_logits - (logits * smoothed_labels) + (
- jnp.log(1 + jnp.exp(-abs_logits)))
- return jnp.where(mask, losses, 0.)
+ losses = (
+ relu_logits
+ - (logits * smoothed_labels)
+ + (jnp.log(1 + jnp.exp(-abs_logits)))
+ )
+ return jnp.where(mask, losses, 0.0)
def _eval_metric(self, labels, logits, masks):
loss = self.loss_fn(labels, logits, masks)
return metrics.EvalMetrics.single_from_model_output(
- loss=loss['per_example'], logits=logits, labels=labels, mask=masks)
+ loss=loss['per_example'], logits=logits, labels=labels, mask=masks
+ )
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, 0, 0, 0, None),
- static_broadcasted_argnums=(0,))
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, 0, 0, 0, None),
+ static_broadcasted_argnums=(0,),
+ )
def _eval_batch(self, params, batch, model_state, rng):
return super()._eval_batch(params, batch, model_state, rng)
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
del num_examples
total_metrics = total_metrics.reduce()
@@ -124,7 +131,6 @@ def _normalize_eval_metrics(
class OgbgGeluWorkload(OgbgWorkload):
-
@property
def activation_fn_name(self) -> str:
"""Name of the activation function to use. One of 'relu', 'gelu', 'silu'."""
@@ -140,7 +146,6 @@ def test_target_value(self) -> float:
class OgbgSiluWorkload(OgbgWorkload):
-
@property
def activation_fn_name(self) -> str:
"""Name of the activation function to use. One of 'relu', 'gelu', 'silu'."""
@@ -156,7 +161,6 @@ def test_target_value(self) -> float:
class OgbgModelSizeWorkload(OgbgWorkload):
-
@property
def hidden_dims(self) -> Tuple[int]:
return (256, 256)
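
`_binary_cross_entropy_with_mask` is only reflowed here; the math is unchanged. The stabilized form it implements is `relu(z) - z * y + log(1 + exp(-|z|))`, zeroed where the mask is off. A small self-check of that identity against the naive formula on toy values:

```python
import jax.numpy as jnp


def stable_bce(logits, labels, mask, label_smoothing=0.0):
  num_classes = labels.shape[-1]
  smoothed = (1.0 - label_smoothing) * labels + label_smoothing / num_classes
  positive = logits >= 0
  relu_logits = jnp.where(positive, logits, 0)
  abs_logits = jnp.where(positive, logits, -logits)
  losses = relu_logits - logits * smoothed + jnp.log(1 + jnp.exp(-abs_logits))
  return jnp.where(mask, losses, 0.0)


logits = jnp.array([[3.0, -2.0], [0.5, -7.0]])
labels = jnp.array([[1.0, 0.0], [0.0, 1.0]])
mask = jnp.array([[True, True], [True, False]])

# Naive cross entropy; fine at these magnitudes, but it overflows for large
# |logits|, which is the reason for the stabilized expression above.
probs = 1 / (1 + jnp.exp(-logits))
naive = -(labels * jnp.log(probs) + (1 - labels) * jnp.log(1 - probs))
print(
  jnp.allclose(
    stable_bce(logits, labels, mask), jnp.where(mask, naive, 0.0), atol=1e-5
  )
)  # True
```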
diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models.py b/algoperf/workloads/ogbg/ogbg_pytorch/models.py
index fe9b29bc1..a69bc6ee1 100644
--- a/algoperf/workloads/ogbg/ogbg_pytorch/models.py
+++ b/algoperf/workloads/ogbg/ogbg_pytorch/models.py
@@ -4,22 +4,26 @@
from typing import Callable, Optional, Tuple
import jax.tree_util as tree
-from jraph import GraphsTuple
import torch
+from jraph import GraphsTuple
from torch import nn
from algoperf import init_utils
+from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout
+
+DROPOUT_RATE = 0.1
-def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn):
+def _make_mlp(in_dim, hidden_dims, activation_fn):
"""Creates a MLP with specified dimensions."""
- layers = nn.Sequential()
+ layers = SequentialWithDropout()
for i, dim in enumerate(hidden_dims):
- layers.add_module(f'dense_{i}',
- nn.Linear(in_features=in_dim, out_features=dim))
+ layers.add_module(
+ f'dense_{i}', nn.Linear(in_features=in_dim, out_features=dim)
+ )
layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6))
layers.add_module(f'activation_fn_{i}', activation_fn())
- layers.add_module(f'dropout_{i}', nn.Dropout(dropout_rate))
+ layers.add_module(f'dropout_{i}', CustomDropout())
in_dim = dim
return layers
@@ -31,20 +35,19 @@ class GNN(nn.Module):
variables. The final prediction will be encoded in the globals.
"""
- def __init__(self,
- num_outputs: int = 128,
- dropout_rate: Optional[float] = 0.1,
- activation_fn_name: str = 'relu',
- latent_dim: int = 256,
- hidden_dims: Tuple[int] = (256,),
- num_message_passing_steps: int = 5) -> None:
+ def __init__(
+ self,
+ num_outputs: int = 128,
+ activation_fn_name: str = 'relu',
+ latent_dim: int = 256,
+ hidden_dims: Tuple[int] = (256,),
+ num_message_passing_steps: int = 5,
+ ) -> None:
super().__init__()
self.latent_dim = latent_dim
self.hidden_dims = hidden_dims
self.num_message_passing_steps = num_message_passing_steps
self.num_outputs = num_outputs
- if dropout_rate is None:
- dropout_rate = 0.1
# in_features are specifically chosen for the ogbg workload.
self.node_embedder = nn.Linear(in_features=9, out_features=self.latent_dim)
self.edge_embedder = nn.Linear(in_features=3, out_features=self.latent_dim)
@@ -57,7 +60,8 @@ def __init__(self,
activation_fn = nn.SiLU
else:
raise ValueError(
- f'Invalid activation function name: {self.activation_fn_name}')
+ f'Invalid activation function name: {self.activation_fn_name}'
+ )
graph_network_layers = []
for st in range(self.num_message_passing_steps):
@@ -65,8 +69,9 @@ def __init__(self,
# specifically update_edge_fn update_node_fn and update_global_fn.
if st == 0:
in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs
- in_dim_node_fn = self.latent_dim + self.hidden_dims[
- -1] * 2 + self.num_outputs
+ in_dim_node_fn = (
+ self.latent_dim + self.hidden_dims[-1] * 2 + self.num_outputs
+ )
last_in_dim = self.hidden_dims[-1] * 2 + self.num_outputs
else:
in_dim_edge_fn = self.hidden_dims[-1] * 4
@@ -74,36 +79,40 @@ def __init__(self,
last_in_dim = self.hidden_dims[-1] * 3
graph_network_layers.append(
- GraphNetwork(
- update_edge_fn=_make_mlp(in_dim_edge_fn,
- self.hidden_dims,
- dropout_rate,
- activation_fn),
- update_node_fn=_make_mlp(in_dim_node_fn,
- self.hidden_dims,
- dropout_rate,
- activation_fn),
- update_global_fn=_make_mlp(last_in_dim,
- self.hidden_dims,
- dropout_rate,
- activation_fn)))
- self.graph_network = nn.Sequential(*graph_network_layers)
+ GraphNetwork(
+ update_edge_fn=_make_mlp(
+ in_dim_edge_fn, self.hidden_dims, activation_fn
+ ),
+ update_node_fn=_make_mlp(
+ in_dim_node_fn, self.hidden_dims, activation_fn
+ ),
+ update_global_fn=_make_mlp(
+ last_in_dim, self.hidden_dims, activation_fn
+ ),
+ )
+ )
+ self.graph_network = SequentialWithDropout(*graph_network_layers)
self.decoder = nn.Linear(
- in_features=self.hidden_dims[-1], out_features=self.num_outputs)
+ in_features=self.hidden_dims[-1], out_features=self.num_outputs
+ )
for m in self.modules():
if isinstance(m, nn.Linear):
init_utils.pytorch_default_init(m)
- def forward(self, graph: GraphsTuple) -> torch.Tensor:
+ def forward(
+ self, graph: GraphsTuple, dropout_rate: float = DROPOUT_RATE
+ ) -> torch.Tensor:
graph = graph._replace(
- globals=torch.zeros([graph.n_node.shape[0], self.num_outputs],
- device=graph.n_node.device))
+ globals=torch.zeros(
+ [graph.n_node.shape[0], self.num_outputs], device=graph.n_node.device
+ )
+ )
graph = graph._replace(nodes=self.node_embedder(graph.nodes))
graph = graph._replace(edges=self.edge_embedder(graph.edges))
- graph = self.graph_network(graph)
+ graph = self.graph_network(graph, dropout_rate)
# Map globals to represent the final result
graph = graph._replace(globals=self.decoder(graph.globals))
@@ -137,16 +146,19 @@ class GraphNetwork(nn.Module):
A method that applies the configured GraphNetwork.
"""
- def __init__(self,
- update_edge_fn: Optional[Callable] = None,
- update_node_fn: Optional[Callable] = None,
- update_global_fn: Optional[Callable] = None) -> None:
+ def __init__(
+ self,
+ update_edge_fn: Optional[Callable] = None,
+ update_node_fn: Optional[Callable] = None,
+ update_global_fn: Optional[Callable] = None,
+ ) -> None:
super().__init__()
self.update_edge_fn = update_edge_fn
self.update_node_fn = update_node_fn
self.update_global_fn = update_global_fn
+ self._supports_custom_dropout = True # supports SequentialWithDropout
- def forward(self, graph: GraphsTuple) -> GraphsTuple:
+ def forward(self, graph: GraphsTuple, dropout_rate: float) -> GraphsTuple:
"""Applies a configured GraphNetwork to a graph.
This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261
There is one difference. For the nodes update the class aggregates over the
@@ -159,42 +171,49 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple:
GraphNets, for more information please see the paper.
Args:
graph: a `GraphsTuple` containing the graph.
+ dropout_rate: dropout probability value.
Returns:
Updated `GraphsTuple`.
"""
nodes, edges, receivers, senders, globals_, n_node, n_edge = graph
sum_n_node = tree.tree_leaves(nodes)[0].shape[0]
if not tree.tree_all(
- tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)):
+ tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)
+ ):
raise ValueError(
- 'All node arrays in nest must contain the same number of nodes.')
+ 'All node arrays in nest must contain the same number of nodes.'
+ )
sent_attributes = tree.tree_map(lambda n: n[senders], nodes)
received_attributes = tree.tree_map(lambda n: n[receivers], nodes)
# Here we scatter the global features to the corresponding edges,
# giving us tensors of shape [num_edges, global_feat].
global_edge_attributes = tree.tree_map(
- lambda g: torch.repeat_interleave(g, n_edge, dim=0), globals_)
+ lambda g: torch.repeat_interleave(g, n_edge, dim=0), globals_
+ )
if self.update_edge_fn:
edge_fn_inputs = torch.cat(
- [edges, sent_attributes, received_attributes, global_edge_attributes],
- dim=-1)
- edges = self.update_edge_fn(edge_fn_inputs)
+ [edges, sent_attributes, received_attributes, global_edge_attributes],
+ dim=-1,
+ )
+ edges = self.update_edge_fn(edge_fn_inputs, dropout_rate)
if self.update_node_fn:
sent_attributes = tree.tree_map(
- lambda e: scatter_sum(e, senders, dim=0, dim_size=sum_n_node), edges)
+ lambda e: scatter_sum(e, senders, dim=0, dim_size=sum_n_node), edges
+ )
received_attributes = tree.tree_map(
- lambda e: scatter_sum(e, receivers, dim=0, dim_size=sum_n_node),
- edges)
+ lambda e: scatter_sum(e, receivers, dim=0, dim_size=sum_n_node), edges
+ )
# Here we scatter the global features to the corresponding nodes,
# giving us tensors of shape [num_nodes, global_feat].
global_attributes = tree.tree_map(
- lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_)
+ lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_
+ )
node_fn_inputs = torch.cat(
- [nodes, sent_attributes, received_attributes, global_attributes],
- dim=-1)
- nodes = self.update_node_fn(node_fn_inputs)
+ [nodes, sent_attributes, received_attributes, global_attributes], dim=-1
+ )
+ nodes = self.update_node_fn(node_fn_inputs, dropout_rate)
if self.update_global_fn:
n_graph = n_node.shape[0]
@@ -207,31 +226,37 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple:
edge_gr_idx = torch.repeat_interleave(graph_idx, n_edge, dim=0)
# We use the aggregation function to pool the nodes/edges per graph.
node_attributes = tree.tree_map(
- lambda n: scatter_sum(n, node_gr_idx, dim=0, dim_size=n_graph), nodes)
+ lambda n: scatter_sum(n, node_gr_idx, dim=0, dim_size=n_graph), nodes
+ )
edge_attributes = tree.tree_map(
- lambda e: scatter_sum(e, edge_gr_idx, dim=0, dim_size=n_graph), edges)
+ lambda e: scatter_sum(e, edge_gr_idx, dim=0, dim_size=n_graph), edges
+ )
# These pooled nodes are the inputs to the global update fn.
- global_fn_inputs = torch.cat([node_attributes, edge_attributes, globals_],
- dim=-1)
- globals_ = self.update_global_fn(global_fn_inputs)
+ global_fn_inputs = torch.cat(
+ [node_attributes, edge_attributes, globals_], dim=-1
+ )
+ globals_ = self.update_global_fn(global_fn_inputs, dropout_rate)
return GraphsTuple(
- nodes=nodes,
- edges=edges,
- receivers=receivers,
- senders=senders,
- globals=globals_,
- n_node=n_node,
- n_edge=n_edge)
+ nodes=nodes,
+ edges=edges,
+ receivers=receivers,
+ senders=senders,
+ globals=globals_,
+ n_node=n_node,
+ n_edge=n_edge,
+ )
# Forked from
# github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py.
-def scatter_sum(src: torch.Tensor,
- index: torch.Tensor,
- dim: int = -1,
- out: Optional[torch.Tensor] = None,
- dim_size: Optional[int] = None) -> torch.Tensor:
+def scatter_sum(
+ src: torch.Tensor,
+ index: torch.Tensor,
+ dim: int = -1,
+ out: Optional[torch.Tensor] = None,
+ dim_size: Optional[int] = None,
+) -> torch.Tensor:
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
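
On the PyTorch side the dropout rate is now threaded through the forward pass: `SequentialWithDropout` is assumed to forward `dropout_rate` only to children that opt in (`CustomDropout`, or modules setting `_supports_custom_dropout`), which is why `GraphNetwork.forward` gained that argument. A rough sketch of such a container with hypothetical class names, not the actual `algoperf.pytorch_utils` implementation:

```python
import torch
import torch.nn.functional as F
from torch import nn


class CallTimeDropout(nn.Module):
  """Dropout whose probability is supplied per forward call."""

  def forward(self, x: torch.Tensor, p: float) -> torch.Tensor:
    return F.dropout(x, p=p, training=self.training)


class SequentialWithCallTimeDropout(nn.Sequential):
  """Sequential container that passes `dropout_rate` only to children that
  accept it (a sketch of the assumed SequentialWithDropout behaviour)."""

  def forward(self, x, dropout_rate: float):
    for module in self:
      takes_rate = isinstance(module, CallTimeDropout) or getattr(
        module, '_supports_custom_dropout', False
      )
      x = module(x, dropout_rate) if takes_rate else module(x)
    return x


mlp = SequentialWithCallTimeDropout(nn.Linear(8, 16), nn.ReLU(), CallTimeDropout())
out = mlp(torch.ones(4, 8), dropout_rate=0.1)
```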
diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
index 45295ac7f..f72ff5141 100644
--- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
+++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
@@ -1,17 +1,17 @@
"""OGBG workload implemented in PyTorch."""
+
import contextlib
from typing import Any, Callable, Dict, Optional, Tuple
import jax
-from jraph import GraphsTuple
import torch
import torch.distributed as dist
+from jraph import GraphsTuple
from torch.nn.parallel import DistributedDataParallel as DDP
-from algoperf import param_utils
-from algoperf import pytorch_utils
-from algoperf import spec
+from algoperf import param_utils, pytorch_utils, spec
from algoperf.workloads.ogbg import metrics
+from algoperf.workloads.ogbg.ogbg_pytorch import models
from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN
from algoperf.workloads.ogbg.workload import BaseOgbgWorkload
@@ -22,9 +22,11 @@ def _pytorch_map(inputs: Any) -> Any:
if USE_PYTORCH_DDP:
return jax.tree.map(lambda a: torch.as_tensor(a, device=DEVICE), inputs)
return jax.tree.map(
- lambda a: torch.as_tensor(a, device=DEVICE).view(-1, a.shape[-1])
- if len(a.shape) == 3 else torch.as_tensor(a, device=DEVICE).view(-1),
- inputs)
+ lambda a: torch.as_tensor(a, device=DEVICE).view(-1, a.shape[-1])
+ if len(a.shape) == 3
+ else torch.as_tensor(a, device=DEVICE).view(-1),
+ inputs,
+ )
def _shard(inputs: Any) -> Any:
@@ -35,44 +37,47 @@ def _shard(inputs: Any) -> Any:
def _graph_map(function: Callable, graph: GraphsTuple) -> GraphsTuple:
return GraphsTuple(
- nodes=function(graph.nodes),
- edges=function(graph.edges),
- receivers=function(graph.receivers),
- senders=function(graph.senders),
- globals=function(graph.globals),
- n_node=function(graph.n_node),
- n_edge=function(graph.n_edge))
+ nodes=function(graph.nodes),
+ edges=function(graph.edges),
+ receivers=function(graph.receivers),
+ senders=function(graph.senders),
+ globals=function(graph.globals),
+ n_node=function(graph.n_node),
+ n_edge=function(graph.n_edge),
+ )
class OgbgWorkload(BaseOgbgWorkload):
-
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
valid examples in batch, 'per_example': 1-d array of per-example losses}
(not synced across devices).
"""
- loss_dict = super().loss_fn(label_batch,
- logits_batch,
- mask_batch,
- label_smoothing)
+ loss_dict = super().loss_fn(
+ label_batch, logits_batch, mask_batch, label_smoothing
+ )
loss_dict['n_valid_examples'] = torch.as_tensor(
- loss_dict['n_valid_examples'], device=DEVICE)
+ loss_dict['n_valid_examples'], device=DEVICE
+ )
return loss_dict
- def _build_input_queue(self,
- data_rng: jax.random.PRNGKey,
- split: str,
- data_dir: str,
- global_batch_size: int):
+ def _build_input_queue(
+ self,
+ data_rng: jax.random.PRNGKey,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ ):
# TODO: Check where the + 1 comes from.
per_device_batch_size = int(global_batch_size / N_GPUS) + 1
@@ -80,10 +85,9 @@ def _build_input_queue(self,
# avoid creating too many threads.
if RANK == 0:
data_rng = data_rng.astype('uint32')
- dataset_iter = super()._build_input_queue(data_rng,
- split,
- data_dir,
- global_batch_size)
+ dataset_iter = super()._build_input_queue(
+ data_rng, split, data_dir, global_batch_size
+ )
while True:
if RANK == 0:
@@ -91,14 +95,16 @@ def _build_input_queue(self,
graph = _graph_map(_pytorch_map, batch['inputs'])
targets = torch.as_tensor(batch['targets'], device=DEVICE)
weights = torch.as_tensor(
- batch['weights'], dtype=torch.bool, device=DEVICE)
+ batch['weights'], dtype=torch.bool, device=DEVICE
+ )
# Send batch to other devices when using DDP.
if USE_PYTORCH_DDP:
dist.broadcast_object_list([graph], src=0, device=DEVICE)
# During eval, the batch size of the remainder might be different.
if split != 'train':
per_device_batch_size = torch.tensor(
- len(targets[0]), dtype=torch.int32, device=DEVICE)
+ len(targets[0]), dtype=torch.int32, device=DEVICE
+ )
dist.broadcast(per_device_batch_size, src=0)
dist.broadcast(targets, src=0)
targets = targets[0]
@@ -113,44 +119,40 @@ def _build_input_queue(self,
graph = graph[0]
# During eval, the batch size of the remainder might be different.
if split != 'train':
- per_device_batch_size = torch.empty((1,),
- dtype=torch.int32,
- device=DEVICE)
+ per_device_batch_size = torch.empty(
+ (1,), dtype=torch.int32, device=DEVICE
+ )
dist.broadcast(per_device_batch_size, src=0)
targets = torch.empty(
- (N_GPUS, per_device_batch_size, self._num_outputs), device=DEVICE)
+ (N_GPUS, per_device_batch_size, self._num_outputs), device=DEVICE
+ )
dist.broadcast(targets, src=0)
targets = targets[RANK]
weights = torch.empty(
- (N_GPUS, per_device_batch_size, self._num_outputs),
- dtype=torch.bool,
- device=DEVICE)
+ (N_GPUS, per_device_batch_size, self._num_outputs),
+ dtype=torch.bool,
+ device=DEVICE,
+ )
dist.broadcast(weights, src=0)
weights = weights[RANK]
batch = {
- 'inputs': _graph_map(_shard, graph),
- 'targets': targets,
- 'weights': weights,
+ 'inputs': _graph_map(_shard, graph),
+ 'targets': targets,
+ 'weights': weights,
}
yield batch
- def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """aux_dropout_rate is unused."""
- del aux_dropout_rate
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
model = GNN(
- num_outputs=self._num_outputs,
- dropout_rate=dropout_rate,
- hidden_dims=self.hidden_dims,
- latent_dim=self.latent_dim,
- num_message_passing_steps=self.num_message_passing_steps,
- activation_fn_name=self.activation_fn_name)
+ num_outputs=self._num_outputs,
+ hidden_dims=self.hidden_dims,
+ latent_dim=self.latent_dim,
+ num_message_passing_steps=self.num_message_passing_steps,
+ activation_fn_name=self.activation_fn_name,
+ )
self._param_shapes = param_utils.pytorch_param_shapes(model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
model.to(DEVICE)
@@ -165,19 +167,22 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['decoder.weight', 'decoder.bias']
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
"""Get predicted logits from the network for input graphs."""
del rng
del update_batch_norm # No BN in the GNN model.
if model_state is not None:
raise ValueError(
- f'Expected model_state to be None, received {model_state}.')
+ f'Expected model_state to be None, received {model_state}.'
+ )
model = params
if mode == spec.ForwardPassMode.TRAIN:
@@ -186,26 +191,31 @@ def model_fn(
model.eval()
contexts = {
- spec.ForwardPassMode.EVAL: torch.no_grad,
- spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+ spec.ForwardPassMode.EVAL: torch.no_grad,
+ spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
- logits = model(augmented_and_preprocessed_input_batch['inputs'])
+ logits = model(
+ augmented_and_preprocessed_input_batch['inputs'],
+ dropout_rate=dropout_rate,
+ )
return logits, None
def _binary_cross_entropy_with_mask(
- self,
- labels: torch.Tensor,
- logits: torch.Tensor,
- mask: torch.Tensor,
- label_smoothing: float = 0.0) -> torch.Tensor:
+ self,
+ labels: torch.Tensor,
+ logits: torch.Tensor,
+ mask: torch.Tensor,
+ label_smoothing: float = 0.0,
+ ) -> torch.Tensor:
"""Binary cross entropy loss for logits, with masked elements."""
if not (logits.shape == labels.shape == mask.shape): # pylint: disable=superfluous-parens
raise ValueError(
- f'Shape mismatch between logits ({logits.shape}), targets '
- f'({labels.shape}), and weights ({mask.shape}).')
+ f'Shape mismatch between logits ({logits.shape}), targets '
+ f'({labels.shape}), and weights ({mask.shape}).'
+ )
if len(logits.shape) != 2:
raise ValueError(f'Rank of logits ({logits.shape}) must be 2.')
@@ -215,36 +225,40 @@ def _binary_cross_entropy_with_mask(
# Apply label_smoothing.
num_classes = labels.shape[-1]
- smoothed_labels = ((1.0 - label_smoothing) * labels +
- label_smoothing / num_classes)
+ smoothed_labels = (
+ 1.0 - label_smoothing
+ ) * labels + label_smoothing / num_classes
# Numerically stable implementation of BCE loss.
# This mimics TensorFlow's tf.nn.sigmoid_cross_entropy_with_logits().
positive_logits = logits >= 0
relu_logits = torch.where(positive_logits, logits, 0)
abs_logits = torch.where(positive_logits, logits, -logits)
- losses = relu_logits - (logits * smoothed_labels) + (
- torch.log(1 + torch.exp(-abs_logits)))
- return torch.where(mask.to(torch.bool), losses, 0.)
+ losses = (
+ relu_logits
+ - (logits * smoothed_labels)
+ + (torch.log(1 + torch.exp(-abs_logits)))
+ )
+ return torch.where(mask.to(torch.bool), losses, 0.0)
def _eval_metric(self, labels, logits, masks):
loss = self.loss_fn(labels, logits, masks)
return metrics.EvalMetrics.single_from_model_output(
- loss=loss['per_example'].cpu().numpy(),
- logits=logits.cpu().numpy(),
- labels=labels.cpu().numpy(),
- mask=masks.cpu().numpy())
+ loss=loss['per_example'].cpu().numpy(),
+ logits=logits.cpu().numpy(),
+ labels=labels.cpu().numpy(),
+ mask=masks.cpu().numpy(),
+ )
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
del num_examples
return {k: float(v) for k, v in total_metrics.compute().items()}
class OgbgGeluWorkload(OgbgWorkload):
-
@property
def activation_fn_name(self) -> str:
"""Name of the activation function to use. One of 'relu', 'gelu', 'silu'."""
@@ -260,7 +274,6 @@ def test_target_value(self) -> float:
class OgbgSiluWorkload(OgbgWorkload):
-
@property
def activation_fn_name(self) -> str:
"""Name of the activation function to use. One of 'relu', 'gelu', 'silu'."""
@@ -276,7 +289,6 @@ def test_target_value(self) -> float:
class OgbgModelSizeWorkload(OgbgWorkload):
-
@property
def hidden_dims(self) -> Tuple[int]:
return (256, 256)
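
The PyTorch `loss_fn` keeps its return contract after the reflow: a dict holding the summed loss, the count of valid (unmasked) examples, and the per-example losses, from which submissions typically form a mean. A toy illustration of that contract:

```python
import torch

per_example = torch.tensor([0.40, 0.10, 0.90, 0.25])
mask = torch.tensor([1.0, 1.0, 0.0, 1.0])  # third example is padding

masked = per_example * mask
loss_dict = {
  'summed': masked.sum(),
  'n_valid_examples': mask.sum(),
  'per_example': masked,
}
# Mean over valid examples only, the usual way this dict is consumed.
mean_loss = loss_dict['summed'] / loss_dict['n_valid_examples']
print(mean_loss)  # tensor(0.2500)
```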
diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py
index 971e7f0f6..8717e46d6 100644
--- a/algoperf/workloads/ogbg/workload.py
+++ b/algoperf/workloads/ogbg/workload.py
@@ -9,12 +9,10 @@
from algoperf import random_utils as prng
from algoperf import spec
-from algoperf.workloads.ogbg import input_pipeline
-from algoperf.workloads.ogbg import metrics
+from algoperf.workloads.ogbg import input_pipeline, metrics
class BaseOgbgWorkload(spec.Workload):
-
_num_outputs: int = 128
@property
@@ -40,8 +38,10 @@ def num_message_passing_steps(self) -> int:
return 5
def has_reached_validation_target(self, eval_result: float) -> bool:
- return eval_result[
- 'validation/mean_average_precision'] > self.validation_target_value
+ return (
+ eval_result['validation/mean_average_precision']
+ > self.validation_target_value
+ )
@property
def validation_target_value(self) -> float:
@@ -94,15 +94,16 @@ def max_allowed_runtime_sec(self) -> int:
def eval_period_time_sec(self) -> int:
return 4 * 60
- def _build_input_queue(self,
- data_rng: jax.random.PRNGKey,
- split: str,
- data_dir: str,
- global_batch_size: int):
- dataset_iter = input_pipeline.get_dataset_iter(split,
- data_rng,
- data_dir,
- global_batch_size)
+ def _build_input_queue(
+ self,
+ data_rng: jax.random.PRNGKey,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ ):
+ dataset_iter = input_pipeline.get_dataset_iter(
+ split, data_rng, data_dir, global_batch_size
+ )
if split != 'train':
# Note that this stores the entire val dataset in memory.
dataset_iter = itertools.cycle(dataset_iter)
@@ -111,11 +112,12 @@ def _build_input_queue(self,
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -123,19 +125,20 @@ def loss_fn(
(not synced across devices).
"""
per_example_losses = self._binary_cross_entropy_with_mask(
- labels=label_batch,
- logits=logits_batch,
- mask=mask_batch,
- label_smoothing=label_smoothing)
+ labels=label_batch,
+ logits=logits_batch,
+ mask=mask_batch,
+ label_smoothing=label_smoothing,
+ )
if mask_batch is not None:
n_valid_examples = mask_batch.sum()
else:
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': n_valid_examples,
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': n_valid_examples,
+ 'per_example': per_example_losses,
}
@property
@@ -145,39 +148,45 @@ def step_hint(self) -> int:
@abc.abstractmethod
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
- def _eval_batch(self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState) -> metrics.EvalMetrics:
+ def _eval_batch(
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ ) -> metrics.EvalMetrics:
logits, _ = self.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.EVAL,
- rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.EVAL,
+ rng,
+ update_batch_norm=False,
+ )
return self._eval_metric(batch['targets'], logits, batch['weights'])
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del global_step
data_rng, model_rng = prng.split(rng, 2)
if split not in self._eval_iters:
self._eval_iters[split] = self._build_input_queue(
- data_rng, split, data_dir, global_batch_size=global_batch_size)
+ data_rng, split, data_dir, global_batch_size=global_batch_size
+ )
total_metrics = None
num_eval_steps = int(math.ceil(float(num_examples) / global_batch_size))
@@ -186,8 +195,10 @@ def _eval_model_on_split(self,
batch = next(self._eval_iters[split])
batch_metrics = self._eval_batch(params, batch, model_state, model_rng)
total_metrics = (
- batch_metrics
- if total_metrics is None else total_metrics.merge(batch_metrics))
+ batch_metrics
+ if total_metrics is None
+ else total_metrics.merge(batch_metrics)
+ )
if total_metrics is None:
return {}
return self._normalize_eval_metrics(num_examples, total_metrics)
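
The `_build_input_queue` hunk above wraps non-train splits in `itertools.cycle`, which is what the in-code note about holding the whole validation set in memory refers to: `cycle` caches every element it yields so it can replay them indefinitely. A tiny illustration:

```python
import itertools

batches = iter(['batch_0', 'batch_1', 'batch_2'])  # stand-in for the val iterator
cycled = itertools.cycle(batches)  # caches items on first pass, then repeats
print([next(cycled) for _ in range(7)])
# ['batch_0', 'batch_1', 'batch_2', 'batch_0', 'batch_1', 'batch_2', 'batch_0']
```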
diff --git a/algoperf/workloads/utils.py b/algoperf/workloads/utils.py
index 7719f91fb..920c3cf46 100644
--- a/algoperf/workloads/utils.py
+++ b/algoperf/workloads/utils.py
@@ -5,10 +5,12 @@
def print_jax_model_summary(model, fake_inputs):
"""Prints a summary of the jax module."""
tabulate_fn = nn.tabulate(
- model,
- jax.random.PRNGKey(0),
- console_kwargs={
- 'force_terminal': False, 'force_jupyter': False, 'width': 240
- },
+ model,
+ jax.random.PRNGKey(0),
+ console_kwargs={
+ 'force_terminal': False,
+ 'force_jupyter': False,
+ 'width': 240,
+ },
)
print(tabulate_fn(fake_inputs, train=False))
diff --git a/algoperf/workloads/wmt/bleu.py b/algoperf/workloads/wmt/bleu.py
index ad314a7d3..6e29b1b83 100644
--- a/algoperf/workloads/wmt/bleu.py
+++ b/algoperf/workloads/wmt/bleu.py
@@ -5,19 +5,17 @@
https://github.com/mjpost/sacrebleu/blob/v1.3.1/sacrebleu.py.
"""
-from collections import Counter
-from collections import namedtuple
-from itertools import zip_longest
-import logging
import math
import re
import sys
-from typing import List, Sequence
import unicodedata
+from collections import Counter, namedtuple
+from itertools import zip_longest
+from typing import List, Sequence
-from absl import logging
import torch
import torch.distributed as dist
+from absl import logging
from algoperf.pytorch_utils import pytorch_setup
@@ -30,11 +28,11 @@
def my_log(num):
"""
- Floors the log function
+ Floors the log function
- :param num: the number
- :return: log(num) floored to a very low number
- """
+ :param num: the number
+ :return: log(num) floored to a very low number
+ """
if num == 0.0:
return -9999999999
@@ -43,12 +41,12 @@ def my_log(num):
def tokenize_13a(line):
"""
- Tokenizes an input line using a relatively minimal tokenization that is
- however equivalent to mteval-v13a, used by WMT.
+ Tokenizes an input line using a relatively minimal tokenization that is
+ however equivalent to mteval-v13a, used by WMT.
- :param line: a segment to tokenize
- :return: the tokenized line
- """
+ :param line: a segment to tokenize
+ :return: the tokenized line
+ """
norm = line
@@ -62,14 +60,17 @@ def tokenize_13a(line):
norm = norm.replace('&gt;', '>')
# language-dependent part (assuming Western languages):
- norm = " {} ".format(norm)
+ norm = ' {} '.format(norm)
norm = re.sub(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', ' \\1 ', norm)
- norm = re.sub(r'([^0-9])([\.,])', '\\1 \\2 ',
- norm) # tokenize period and comma unless preceded by a digit
- norm = re.sub(r'([\.,])([^0-9])', ' \\1 \\2',
- norm) # tokenize period and comma unless followed by a digit
- norm = re.sub(r'([0-9])(-)', '\\1 \\2 ',
- norm) # tokenize dash when preceded by a digit
+ norm = re.sub(
+ r'([^0-9])([\.,])', '\\1 \\2 ', norm
+ ) # tokenize period and comma unless preceded by a digit
+ norm = re.sub(
+ r'([\.,])([^0-9])', ' \\1 \\2', norm
+ ) # tokenize period and comma unless followed by a digit
+ norm = re.sub(
+ r'([0-9])(-)', '\\1 \\2 ', norm
+ ) # tokenize dash when preceded by a digit
norm = re.sub(r'\s+', ' ', norm) # one space only between words
norm = re.sub(r'^\s+', '', norm) # no leading space
norm = re.sub(r'\s+$', '', norm) # no trailing space
@@ -80,14 +81,15 @@ def tokenize_13a(line):
class UnicodeRegex:
"""Ad-hoc hack to recognize all punctuation and symbols.
- without depending on https://pypi.python.org/pypi/regex/."""
+ without depending on https://pypi.python.org/pypi/regex/."""
@staticmethod
def _property_chars(prefix):
return ''.join(
- chr(x)
- for x in range(sys.maxunicode)
- if unicodedata.category(chr(x)).startswith(prefix))
+ chr(x)
+ for x in range(sys.maxunicode)
+ if unicodedata.category(chr(x)).startswith(prefix)
+ )
punctuation = _property_chars('P')
nondigit_punct_re = re.compile(r'([^\d])([' + punctuation + r'])')
@@ -98,27 +100,27 @@ def _property_chars(prefix):
def tokenize_v14_international(string):
r"""Tokenize a string following the official BLEU implementation.
- See
- https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
- In our case, the input string is expected to be just one line
- and no HTML entities de-escaping is needed.
- So we just tokenize on punctuation and symbols,
- except when a punctuation is preceded and followed by a digit
- (e.g. a comma/dot as a thousand/decimal separator).
-
- Note that a number (e.g., a year) followed by a dot at the end of sentence
- is NOT tokenized,
- i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g`
- does not match this case (unless we add a space after each sentence).
- However, this error is already in the original mteval-v14.pl
- and we want to be consistent with it.
- The error is not present in the non-international version,
- which uses,
- `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python).
-
- :param string: the input string
- :return: a list of tokens
- """
+ See
+ https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
+ In our case, the input string is expected to be just one line
+ and no HTML entities de-escaping is needed.
+ So we just tokenize on punctuation and symbols,
+ except when a punctuation is preceded and followed by a digit
+ (e.g. a comma/dot as a thousand/decimal separator).
+
+ Note that a number (e.g., a year) followed by a dot at the end of sentence
+ is NOT tokenized,
+ i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g`
+ does not match this case (unless we add a space after each sentence).
+ However, this error is already in the original mteval-v14.pl
+ and we want to be consistent with it.
+ The error is not present in the non-international version,
+ which uses,
+ `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python).
+
+ :param string: the input string
+ :return: a list of tokens
+ """
string = UnicodeRegex.nondigit_punct_re.sub(r'\1 \2 ', string)
string = UnicodeRegex.punct_nondigit_re.sub(r' \1 \2', string)
string = UnicodeRegex.symbol_re.sub(r' \1 ', string)
@@ -127,94 +129,94 @@ def tokenize_v14_international(string):
def tokenize_zh(sentence):
"""MIT License
- Copyright (c) 2017 - Shujian Huang
-
- Permission is hereby granted, free of charge, to any person obtaining
- a copy of this software and associated documentation files
- (the "Software"), to deal in the Software without restriction, including
- without limitation the rights to use, copy, modify, merge, publish,
- distribute, sublicense, and/or sell copies of the Software, and to
- permit persons to whom the Software is furnished to do so, subject to the
- following conditions:
-
- The above copyright notice and this permission notice shall be included
- in all copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
- OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
- DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
- OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
- USE OR OTHER DEALINGS IN THE SOFTWARE.
-
- The tokenization of Chinese text in this script contains two steps:
- separate each Chinese characters (by utf-8 encoding);
- tokenize the non Chinese part (following the mteval script).
- Author: Shujian Huang huangsj@nju.edu.cn
-
- :param sentence: input sentence
- :return: tokenized sentence
- """
+ Copyright (c) 2017 - Shujian Huang
+
+ Permission is hereby granted, free of charge, to any person obtaining
+ a copy of this software and associated documentation files
+ (the "Software"), to deal in the Software without restriction, including
+ without limitation the rights to use, copy, modify, merge, publish,
+ distribute, sublicense, and/or sell copies of the Software, and to
+ permit persons to whom the Software is furnished to do so, subject to the
+ following conditions:
+
+ The above copyright notice and this permission notice shall be included
+ in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
+ USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+ The tokenization of Chinese text in this script contains two steps:
+ separate each Chinese characters (by utf-8 encoding);
+ tokenize the non Chinese part (following the mteval script).
+ Author: Shujian Huang huangsj@nju.edu.cn
+
+ :param sentence: input sentence
+ :return: tokenized sentence
+ """
def is_chinese_char(uchar):
"""
- :param uchar: input char in unicode
- :return: whether the input char is a Chinese character.
- """
- if "\u3400" <= uchar <= "\u4db5":
+ :param uchar: input char in unicode
+ :return: whether the input char is a Chinese character.
+ """
+ if '\u3400' <= uchar <= '\u4db5':
return True
- elif "\u4e00" <= uchar <= "\u9fa5":
+ elif '\u4e00' <= uchar <= '\u9fa5':
return True
- elif "\u9fa6" <= uchar <= "\u9fbb":
+ elif '\u9fa6' <= uchar <= '\u9fbb':
return True
- elif "\uf900" <= uchar <= "\ufa2d":
+ elif '\uf900' <= uchar <= '\ufa2d':
return True
- elif "\ufa30" <= uchar <= "\ufa6a":
+ elif '\ufa30' <= uchar <= '\ufa6a':
return True
- elif "\ufa70" <= uchar <= "\ufad9":
+ elif '\ufa70' <= uchar <= '\ufad9':
return True
- elif "\u20000" <= uchar <= "\u2a6d6":
+ elif '\u20000' <= uchar <= '\u2a6d6':
return True
- elif "\u2f800" <= uchar <= "\u2fa1d":
+ elif '\u2f800' <= uchar <= '\u2fa1d':
return True
- elif "\uff00" <= uchar <= "\uffef":
+ elif '\uff00' <= uchar <= '\uffef':
return True
- elif "\u2e80" <= uchar <= "\u2eff":
+ elif '\u2e80' <= uchar <= '\u2eff':
return True
- elif "\u3000" <= uchar <= "\u303f":
+ elif '\u3000' <= uchar <= '\u303f':
return True
- elif "\u31c0" <= uchar <= "\u31ef":
+ elif '\u31c0' <= uchar <= '\u31ef':
return True
- elif "\u2f00" <= uchar <= "\u2fdf":
+ elif '\u2f00' <= uchar <= '\u2fdf':
return True
- elif "\u2ff0" <= uchar <= "\u2fff":
+ elif '\u2ff0' <= uchar <= '\u2fff':
return True
- elif "\u3100" <= uchar <= "\u312f":
+ elif '\u3100' <= uchar <= '\u312f':
return True
- elif "\u31a0" <= uchar <= "\u31bf":
+ elif '\u31a0' <= uchar <= '\u31bf':
return True
- elif "\ufe10" <= uchar <= "\ufe1f":
+ elif '\ufe10' <= uchar <= '\ufe1f':
return True
- elif "\ufe30" <= uchar <= "\ufe4f":
+ elif '\ufe30' <= uchar <= '\ufe4f':
return True
- elif "\u2600" <= uchar <= "\u26ff":
+ elif '\u2600' <= uchar <= '\u26ff':
return True
- elif "\u2700" <= uchar <= "\u27bf":
+ elif '\u2700' <= uchar <= '\u27bf':
return True
- elif "\u3200" <= uchar <= "\u32ff":
+ elif '\u3200' <= uchar <= '\u32ff':
return True
- elif "\u3300" <= uchar <= "\u33ff":
+ elif '\u3300' <= uchar <= '\u33ff':
return True
return False
sentence = sentence.strip()
- sentence_in_chars = ""
+ sentence_in_chars = ''
for char in sentence:
if is_chinese_char(char):
- sentence_in_chars += " "
+ sentence_in_chars += ' '
sentence_in_chars += char
- sentence_in_chars += " "
+ sentence_in_chars += ' '
else:
sentence_in_chars += char
sentence = sentence_in_chars
@@ -245,10 +247,10 @@ def is_chinese_char(uchar):
TOKENIZERS = {
- '13a': tokenize_13a,
- 'intl': tokenize_v14_international,
- 'zh': tokenize_zh,
- 'none': lambda x: x,
+ '13a': tokenize_13a,
+ 'intl': tokenize_v14_international,
+ 'zh': tokenize_zh,
+ 'none': lambda x: x,
}
DEFAULT_TOKENIZER = '13a'
@@ -256,16 +258,16 @@ def is_chinese_char(uchar):
def extract_ngrams(line, min_order=1, max_order=NGRAM_ORDER) -> Counter:
"""Extracts all the ngrams (1 <= n <= NGRAM_ORDER) from a sequence of tokens.
- :param line: a segment containing a sequence of words
- :param max_order: collect n-grams from 1<=n<=max
- :return: a dictionary containing ngrams and counts
- """
+ :param line: a segment containing a sequence of words
+ :param max_order: collect n-grams from 1<=n<=max
+ :return: a dictionary containing ngrams and counts
+ """
ngrams = Counter()
tokens = line.split()
for n in range(min_order, max_order + 1):
for i in range(0, len(tokens) - n + 1):
- ngram = ' '.join(tokens[i:i + n])
+ ngram = ' '.join(tokens[i : i + n])
ngrams[ngram] += 1
return ngrams
@@ -293,41 +295,44 @@ def ref_stats(output, refs):
return ngrams, closest_diff, closest_len
-BLEU = namedtuple('BLE',
- 'score, counts, totals, precisions, bp, sys_len, ref_len')
+BLEU = namedtuple(
+ 'BLE', 'score, counts, totals, precisions, bp, sys_len, ref_len'
+)
-def compute_bleu(correct: List[int],
- total: List[int],
- sys_len: int,
- ref_len: int,
- smooth_method='none',
- smooth_value=SMOOTH_VALUE_DEFAULT,
- use_effective_order=False) -> BLEU:
+def compute_bleu(
+ correct: List[int],
+ total: List[int],
+ sys_len: int,
+ ref_len: int,
+ smooth_method='none',
+ smooth_value=SMOOTH_VALUE_DEFAULT,
+ use_effective_order=False,
+) -> BLEU:
"""Computes BLEU score from its sufficient statistics. Adds smoothing.
- Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques
- for Sentence-Level BLEU", Boxing Chen and Colin Cherry,
- WMT 2014: http://aclweb.org/anthology/W14-3346)
-
- - exp: NIST smoothing method (Method 3)
- - floor: Method 1
- - add-k: Method 2 (generalizing Lin and Och, 2004)
- - none: do nothing.
-
- :param correct: List of counts of correct ngrams, 1 <= n <= NGRAM_ORDER
- :param total: List of counts of total ngrams, 1 <= n <= NGRAM_ORDER
- :param sys_len: The cumulative system length
- :param ref_len: The cumulative reference length
- :param smooth: The smoothing method to use
- :param smooth_value: The smoothing value added, if smooth is 'floor'
- :param use_effective_order: Use effective order.
- :return: A BLEU object with the score (100-based) and other statistics.
- """
+ Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques
+ for Sentence-Level BLEU", Boxing Chen and Colin Cherry,
+ WMT 2014: http://aclweb.org/anthology/W14-3346)
+
+ - exp: NIST smoothing method (Method 3)
+ - floor: Method 1
+ - add-k: Method 2 (generalizing Lin and Och, 2004)
+ - none: do nothing.
+
+ :param correct: List of counts of correct ngrams, 1 <= n <= NGRAM_ORDER
+ :param total: List of counts of total ngrams, 1 <= n <= NGRAM_ORDER
+ :param sys_len: The cumulative system length
+ :param ref_len: The cumulative reference length
+ :param smooth: The smoothing method to use
+ :param smooth_value: The smoothing value added, if smooth is 'floor'
+ :param use_effective_order: Use effective order.
+ :return: A BLEU object with the score (100-based) and other statistics.
+ """
precisions = [0 for x in range(NGRAM_ORDER)]
- smooth_mteval = 1.
+ smooth_mteval = 1.0
effective_order = NGRAM_ORDER
for n in range(NGRAM_ORDER):
if smooth_method == 'add-k' and n > 1:
@@ -342,11 +347,11 @@ def compute_bleu(correct: List[int],
if correct[n] == 0:
if smooth_method == 'exp':
smooth_mteval *= 2
- precisions[n] = 100. / (smooth_mteval * total[n])
+ precisions[n] = 100.0 / (smooth_mteval * total[n])
elif smooth_method == 'floor':
- precisions[n] = 100. * smooth_value / total[n]
+ precisions[n] = 100.0 * smooth_value / total[n]
else:
- precisions[n] = 100. * correct[n] / total[n]
+ precisions[n] = 100.0 * correct[n] / total[n]
# If the system guesses no i-grams, 1 <= i <= NGRAM_ORDER, the BLEU
# score is 0 (technically undefined). This is a problem for sentence-level
@@ -360,20 +365,24 @@ def compute_bleu(correct: List[int],
brevity_penalty = math.exp(1 - ref_len / sys_len) if sys_len > 0 else 0.0
bleu = brevity_penalty * math.exp(
- sum(map(my_log, precisions[:effective_order])) / effective_order)
+ sum(map(my_log, precisions[:effective_order])) / effective_order
+ )
return BLEU._make(
- [bleu, correct, total, precisions, brevity_penalty, sys_len, ref_len])
-
-
-def corpus_bleu(sys_stream: Sequence[str],
- ref_streams: Sequence[str],
- smooth_method: str = 'exp',
- smooth_value: float = 0.0,
- force: bool = False,
- lowercase: bool = False,
- tokenize: str = '13a',
- use_effective_order: bool = False) -> BLEU:
+ [bleu, correct, total, precisions, brevity_penalty, sys_len, ref_len]
+ )
+
+
+def corpus_bleu(
+ sys_stream: Sequence[str],
+ ref_streams: Sequence[str],
+ smooth_method: str = 'exp',
+ smooth_value: float = 0.0,
+ force: bool = False,
+ lowercase: bool = False,
+ tokenize: str = '13a',
+ use_effective_order: bool = False,
+) -> BLEU:
"""Produces BLEU scores along with its sufficient statistics from a source
against one or more references.
:param sys_stream: The system stream (a sequence of segments).
@@ -414,13 +423,16 @@ def corpus_bleu(sys_stream: Sequence[str],
tokenized_count += 1
if tokenized_count == 100:
+ logging.warning("That's 100 lines that end in a tokenized period ('.')")
+ logging.warning(
+ 'It looks like you forgot to detokenize your test '
+ 'data, which may hurt your score.'
+ )
logging.warning(
- 'That\'s 100 lines that end in a tokenized period (\'.\')')
- logging.warning('It looks like you forgot to detokenize your test '
- 'data, which may hurt your score.')
- logging.warning('If you insist your data is detokenized, '
- 'or don\'t care, you can suppress this message with '
- '\'--force\'.')
+ 'If you insist your data is detokenized, '
+ "or don't care, you can suppress this message with "
+ "'--force'."
+ )
output, *refs = [TOKENIZERS[tokenize](x.rstrip()) for x in lines]
@@ -453,10 +465,11 @@ def corpus_bleu(sys_stream: Sequence[str],
total = total.cpu().numpy().tolist()
return compute_bleu(
- correct,
- total,
- sys_len,
- ref_len,
- smooth_method=smooth_method,
- smooth_value=smooth_value,
- use_effective_order=use_effective_order)
+ correct,
+ total,
+ sys_len,
+ ref_len,
+ smooth_method=smooth_method,
+ smooth_value=smooth_value,
+ use_effective_order=use_effective_order,
+ )
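
The reformatted `compute_bleu` above works purely from sufficient statistics: clipped n-gram matches, total n-gram counts, and the cumulative system/reference lengths. Below is a minimal sketch of feeding it toy numbers, assuming the module is importable as `algoperf.workloads.wmt.bleu` (as the workload import further down suggests); the counts are made up for illustration, not real corpus statistics.

# Illustrative only: toy sufficient statistics for 4-gram BLEU.
# Assumes `algoperf.workloads.wmt.bleu` is importable in the environment.
from algoperf.workloads.wmt import bleu

correct = [9, 6, 4, 2]   # clipped n-gram matches for n = 1..4 (hypothetical)
total = [12, 11, 10, 9]  # total hypothesis n-grams for n = 1..4 (hypothetical)

result = bleu.compute_bleu(
  correct,
  total,
  sys_len=12,  # cumulative hypothesis length
  ref_len=13,  # cumulative reference length
  smooth_method='exp',  # NIST-style smoothing, as used by corpus_bleu
)
print(f'BLEU = {result.score:.2f}, brevity penalty = {result.bp:.3f}')
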
diff --git a/algoperf/workloads/wmt/input_pipeline.py b/algoperf/workloads/wmt/input_pipeline.py
index d743b43b0..3d184cd78 100644
--- a/algoperf/workloads/wmt/input_pipeline.py
+++ b/algoperf/workloads/wmt/input_pipeline.py
@@ -1,4 +1,5 @@
"""Input pipeline for a WMT dataset."""
+
import functools
import os
from typing import Dict, List, Optional, Union
@@ -16,10 +17,10 @@
Features = Dict[str, tf.Tensor]
TFDS_SPLIT_NAME = {
- 'train': 'train',
- 'eval_train': 'train',
- 'validation': 'validation',
- 'test': 'test',
+ 'train': 'train',
+ 'eval_train': 'train',
+ 'validation': 'validation',
+ 'test': 'test',
}
@@ -31,9 +32,11 @@ def normalize_feature_names(ds_info, features: Features) -> Features:
return features
-def pack_dataset(dataset: tf.data.Dataset,
- key2length: Union[int, Dict[str, int]],
- keys: Optional[List[str]] = None) -> tf.data.Dataset:
+def pack_dataset(
+ dataset: tf.data.Dataset,
+ key2length: Union[int, Dict[str, int]],
+ keys: Optional[List[str]] = None,
+) -> tf.data.Dataset:
"""Creates a 'packed' version of a dataset on-the-fly.
Adapted from the mesh-tf implementation.
This is meant to replace the irritation of having to create a separate
@@ -75,7 +78,8 @@ def pack_dataset(dataset: tf.data.Dataset,
for k in keys:
if k not in shapes:
raise ValueError(
- f'Key {k} not found in dataset. Available keys are {shapes.keys()}')
+ f'Key {k} not found in dataset. Available keys are {shapes.keys()}'
+ )
if not shapes[k].is_compatible_with(tf.TensorShape([None])):
raise ValueError('Tensors to be packed must be one-dimensional.')
# make sure that the length dictionary contains all keys as well as the
@@ -88,13 +92,15 @@ def pack_dataset(dataset: tf.data.Dataset,
# trim to length
dataset = dataset.map(
- lambda x: {k: x[k][:key2length[k]] for k in keys},
- num_parallel_calls=AUTOTUNE)
+ lambda x: {k: x[k][: key2length[k]] for k in keys},
+ num_parallel_calls=AUTOTUNE,
+ )
# Setting batch_size=length ensures that the concatenated sequences (if they
# have length >=1) are sufficient to fill at least one packed example.
batch_size = max(key2length.values())
dataset = dataset.padded_batch(
- batch_size, padded_shapes={k: [-1] for k in keys})
+ batch_size, padded_shapes={k: [-1] for k in keys}
+ )
dataset = _pack_with_tf_ops(dataset, keys, key2length)
# Set the Tensor shapes correctly since they get lost in the process.
@@ -104,9 +110,9 @@ def my_fn(x):
return dataset.map(my_fn, num_parallel_calls=AUTOTUNE)
-def _pack_with_tf_ops(dataset: tf.data.Dataset,
- keys: List[str],
- key2length: Dict[str, int]) -> tf.data.Dataset:
+def _pack_with_tf_ops(
+ dataset: tf.data.Dataset, keys: List[str], key2length: Dict[str, int]
+) -> tf.data.Dataset:
"""Helper-function for packing a dataset which has already been batched.
Helper for pack_dataset() Uses tf.while_loop.
Args:
@@ -127,8 +133,9 @@ def write_packed_example(partial, outputs):
new_outputs = {}
for k in keys_etc:
new_outputs[k] = outputs[k].write(
- outputs[k].size(),
- tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]))
+ outputs[k].size(),
+ tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]),
+ )
return new_partial, new_outputs
def map_fn(x):
@@ -146,9 +153,11 @@ def map_fn(x):
outputs = {}
for k in keys:
outputs[k] = tf.TensorArray(
- tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]])
+ tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]
+ )
outputs[k + '_position'] = tf.TensorArray(
- tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]])
+ tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]
+ )
def body_fn(i, partial, outputs):
"""Body function for while_loop.
@@ -163,13 +172,15 @@ def body_fn(i, partial, outputs):
one_example = {}
for k in keys:
val = tf.cast(x[k][i], tf.int32)
- val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
+ val = val[: tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
one_example[k] = val
for k in keys:
can_append = tf.logical_and(
- can_append,
- tf.less_equal(
- tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]))
+ can_append,
+ tf.less_equal(
+ tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]
+ ),
+ )
def false_fn():
return write_packed_example(partial, outputs)
@@ -180,53 +191,55 @@ def true_fn():
partial, outputs = tf.cond(can_append, true_fn, false_fn)
new_partial = {}
for k in keys:
- new_seq = one_example[k][:key2length[k]]
+ new_seq = one_example[k][: key2length[k]]
new_seq_len = tf.size(new_seq)
new_partial[k] = tf.concat([partial[k], new_seq], 0)
new_partial[k + '_position'] = tf.concat(
- [partial[k + '_position'], tf.range(new_seq_len)], 0)
+ [partial[k + '_position'], tf.range(new_seq_len)], 0
+ )
partial = new_partial
return i + 1, partial, outputs
# For loop over all examples in the batch.
i, partial, outputs = tf.while_loop(
- cond=lambda *_: True,
- body=body_fn,
- loop_vars=(i, partial, outputs),
- shape_invariants=(
- tf.TensorShape([]),
- {k: tf.TensorShape([None]) for k in keys_etc},
- {k: tf.TensorShape(None) for k in keys_etc},
- ),
- maximum_iterations=dynamic_batch_size)
+ cond=lambda *_: True,
+ body=body_fn,
+ loop_vars=(i, partial, outputs),
+ shape_invariants=(
+ tf.TensorShape([]),
+ {k: tf.TensorShape([None]) for k in keys_etc},
+ {k: tf.TensorShape(None) for k in keys_etc},
+ ),
+ maximum_iterations=dynamic_batch_size,
+ )
_, outputs = write_packed_example(partial, outputs)
packed = {k: outputs[k].stack() for k in keys_etc}
for k in keys:
- packed[k + '_segmentation'] = (
- tf.cumsum(
- tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1) *
- tf.cast(tf.not_equal(packed[k], 0), tf.int32))
+ packed[k + '_segmentation'] = tf.cumsum(
+ tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1
+ ) * tf.cast(tf.not_equal(packed[k], 0), tf.int32)
return packed
dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE)
return dataset.unbatch()
-def preprocess_wmt_data(dataset: tf.data.Dataset,
- data_rng,
- train: bool,
- shuffle: bool,
- shuffle_buffer_size: int = 1024,
- max_length: int = 256,
- global_batch_size: int = 128):
+def preprocess_wmt_data(
+ dataset: tf.data.Dataset,
+ data_rng,
+ train: bool,
+ shuffle: bool,
+ shuffle_buffer_size: int = 1024,
+ max_length: int = 256,
+ global_batch_size: int = 128,
+):
"""Shuffle and batch/pack the given dataset."""
def length_filter(max_len):
-
def filter_fn(x):
source, target = x['inputs'], x['targets']
- l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0])
- return tf.less(l, max_len + 1)
+ length = tf.maximum(tf.shape(source)[0], tf.shape(target)[0])
+ return tf.less(length, max_len + 1)
return filter_fn
@@ -242,24 +255,27 @@ def filter_fn(x):
dataset = dataset.batch(global_batch_size, drop_remainder=train)
else: # simple (static-shape) padded batching
dataset = dataset.padded_batch(
- global_batch_size,
- padded_shapes={'inputs': max_length, 'targets': max_length},
- padding_values={'inputs': 0, 'targets': 0},
- drop_remainder=False)
+ global_batch_size,
+ padded_shapes={'inputs': max_length, 'targets': max_length},
+ padding_values={'inputs': 0, 'targets': 0},
+ drop_remainder=False,
+ )
dataset = dataset.prefetch(AUTOTUNE)
return dataset
-def get_wmt_dataset(data_rng,
- split: str,
- data_dir: str,
- is_training: bool,
- vocab_size: int,
- global_batch_size: int,
- num_batches: Optional[int] = None,
- repeat_final_dataset: bool = False,
- vocab_path: Optional[str] = None):
+def get_wmt_dataset(
+ data_rng,
+ split: str,
+ data_dir: str,
+ is_training: bool,
+ vocab_size: int,
+ global_batch_size: int,
+ num_batches: Optional[int] = None,
+ repeat_final_dataset: bool = False,
+ vocab_path: Optional[str] = None,
+):
"""Load and return dataset of batched examples for use during training."""
if vocab_path is None:
vocab_path = os.path.join(data_dir, 'wmt_sentencepiece_model')
@@ -271,7 +287,8 @@ def get_wmt_dataset(data_rng,
dataset_builder = tfds.builder(ds_name, data_dir=data_dir)
ds = dataset_builder.as_dataset(
- split=TFDS_SPLIT_NAME[split], shuffle_files=False)
+ split=TFDS_SPLIT_NAME[split], shuffle_files=False
+ )
# Avoid creating too many threads when using PyTorch DDP.
if RANK != 0:
@@ -280,8 +297,9 @@ def get_wmt_dataset(data_rng,
ds = ds.with_options(options)
ds = ds.map(
- functools.partial(normalize_feature_names, dataset_builder.info),
- num_parallel_calls=AUTOTUNE)
+ functools.partial(normalize_feature_names, dataset_builder.info),
+ num_parallel_calls=AUTOTUNE,
+ )
# Load tf-text SentencePiece tokenizer.
sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path)
@@ -289,12 +307,13 @@ def get_wmt_dataset(data_rng,
shuffle = split in ['train', 'eval_train']
ds = preprocess_wmt_data(
- ds,
- data_rng,
- train=is_training,
- shuffle=shuffle,
- global_batch_size=global_batch_size,
- max_length=256)
+ ds,
+ data_rng,
+ train=is_training,
+ shuffle=shuffle,
+ global_batch_size=global_batch_size,
+ max_length=256,
+ )
if num_batches:
ds = ds.take(num_batches)
@@ -303,9 +322,10 @@ def get_wmt_dataset(data_rng,
ds = ds.repeat()
ds = map(
- functools.partial(
- data_utils.shard_and_maybe_pad_np,
- global_batch_size=global_batch_size),
- ds)
+ functools.partial(
+ data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size
+ ),
+ ds,
+ )
return ds, sp_tokenizer
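
For readers skimming the reformatted `pack_dataset` / `_pack_with_tf_ops` code above: packing concatenates several short examples into one fixed-length row and emits companion `*_position` and `*_segmentation` features so attention can still be restricted to each original example. The following is a pure-Python sketch of that layout with toy sequences, not the tf.while_loop implementation used in the pipeline.

# Toy illustration of the packed layout: several short sequences share one
# fixed-length row; '_position' records the within-example token index and
# '_segmentation' records which original example each token came from.
def pack(sequences, max_len):
  tokens, positions, segments = [], [], []
  seg_id = 0
  for seq in sequences:
    seq = seq[:max_len]
    if len(tokens) + len(seq) > max_len:
      break  # the real pipeline would start a new packed row here
    seg_id += 1
    tokens.extend(seq)
    positions.extend(range(len(seq)))
    segments.extend([seg_id] * len(seq))
  pad = max_len - len(tokens)
  return tokens + [0] * pad, positions + [0] * pad, segments + [0] * pad

toks, pos, seg = pack([[5, 6, 7], [8, 9]], max_len=8)
print(toks)  # [5, 6, 7, 8, 9, 0, 0, 0]
print(pos)   # [0, 1, 2, 0, 1, 0, 0, 0]
print(seg)   # [1, 1, 1, 2, 2, 0, 0, 0]
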
diff --git a/algoperf/workloads/wmt/tokenizer.py b/algoperf/workloads/wmt/tokenizer.py
index 1f001e619..273e11dfa 100644
--- a/algoperf/workloads/wmt/tokenizer.py
+++ b/algoperf/workloads/wmt/tokenizer.py
@@ -9,19 +9,19 @@
import time
from typing import Any, Dict, Iterable, Tuple
-from absl import logging
import jax
-from sentencepiece import SentencePieceTrainer
import tensorflow as tf
import tensorflow_text as tftxt
+from absl import logging
+from sentencepiece import SentencePieceTrainer
Features = Dict[str, tf.Tensor]
def _dump_chars_to_textfile(
- dataset: tf.data.Dataset,
- maxchars: int = int(1e7),
- data_keys=('inputs', 'targets')
+ dataset: tf.data.Dataset,
+ maxchars: int = int(1e7),
+ data_keys=('inputs', 'targets'),
) -> Tuple[str, int]:
"""Write part of a TFDS sentence dataset to lines in a text file.
@@ -36,7 +36,8 @@ def _dump_chars_to_textfile(
char_count = 0
ds_iter = dataset.as_numpy_iterator()
with tempfile.NamedTemporaryFile(
- delete=False, prefix='/tmp/ds_chars') as outfp:
+ delete=False, prefix='/tmp/ds_chars'
+ ) as outfp:
while char_count < maxchars:
example = next(ds_iter)
for k in data_keys:
@@ -46,14 +47,16 @@ def _dump_chars_to_textfile(
return outfp.name, char_count
-def _train_sentencepiece(dataset: tf.data.Dataset,
- *,
- vocab_size: int,
- maxchars: int = int(1e7),
- model_path: str,
- model_type: str = 'unigram',
- character_coverage: float = 1.0,
- data_keys=('inputs', 'targets')):
+def _train_sentencepiece(
+ dataset: tf.data.Dataset,
+ *,
+ vocab_size: int,
+ maxchars: int = int(1e7),
+ model_path: str,
+ model_type: str = 'unigram',
+ character_coverage: float = 1.0,
+ data_keys=('inputs', 'targets'),
+):
"""Train SentencePiece tokenizer from subset of tf dataset.
Args:
@@ -75,17 +78,21 @@ def _train_sentencepiece(dataset: tf.data.Dataset,
else:
abs_model_path = os.path.abspath(os.path.expanduser(model_path))
fname, _ = _dump_chars_to_textfile(
- dataset, maxchars=maxchars, data_keys=data_keys)
+ dataset, maxchars=maxchars, data_keys=data_keys
+ )
with tempfile.NamedTemporaryFile(
- delete=False, prefix='/tmp/sp_tmp') as model_fp:
+ delete=False, prefix='/tmp/sp_tmp'
+ ) as model_fp:
pass # we just want a prefix'd tmp-filename
- argstr = ' '.join([
+ argstr = ' '.join(
+ [
f'--input={fname}',
f'--vocab_size={vocab_size}',
f'--character_coverage={character_coverage}',
f'--model_prefix={model_fp.name}',
f'--model_type={model_type}',
- ])
+ ]
+ )
SentencePieceTrainer.Train(argstr)
if jax.process_index() == 0:
# Use an intermediate filename that is renamed to the target name to address
@@ -100,32 +107,38 @@ def _train_sentencepiece(dataset: tf.data.Dataset,
time.sleep(1)
-def _load_sentencepiece_tokenizer(model_path: str,
- add_bos: bool = False,
- add_eos: bool = True,
- reverse: bool = False):
+def _load_sentencepiece_tokenizer(
+ model_path: str,
+ add_bos: bool = False,
+ add_eos: bool = True,
+ reverse: bool = False,
+):
"""Load a tf-text SentencePiece tokenizer from given model filepath."""
with tf.io.gfile.GFile(model_path, 'rb') as model_fp:
sp_model = model_fp.read()
sp_tokenizer = tftxt.SentencepieceTokenizer(
- model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse)
+ model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse
+ )
return sp_tokenizer
-def train_tokenizer(dataset: tf.data.Dataset,
- *,
- vocab_path: str,
- vocab_size: int,
- max_corpus_chars: int,
- data_keys: Tuple[str, str] = ('inputs', 'targets')):
+def train_tokenizer(
+ dataset: tf.data.Dataset,
+ *,
+ vocab_path: str,
+ vocab_size: int,
+ max_corpus_chars: int,
+ data_keys: Tuple[str, str] = ('inputs', 'targets'),
+):
"""Trains a tokenizer from `dataset`."""
logging.info('Building SentencePiece vocab from data.')
_train_sentencepiece(
- dataset,
- vocab_size=vocab_size,
- maxchars=max_corpus_chars,
- model_path=vocab_path,
- data_keys=data_keys)
+ dataset,
+ vocab_size=vocab_size,
+ maxchars=max_corpus_chars,
+ model_path=vocab_path,
+ data_keys=data_keys,
+ )
def load_tokenizer(vocab_path: str):
@@ -135,7 +148,6 @@ def load_tokenizer(vocab_path: str):
@dataclasses.dataclass
class TokenizeOp:
-
sp_tokenizer: Any
data_keys: Iterable[str] = ('inputs', 'targets')
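
A hedged usage sketch of the tokenizer helpers above: it assumes a SentencePiece model has already been trained and written to `vocab_path` (for example by `train_tokenizer`; the WMT input pipeline stores it as `wmt_sentencepiece_model` in the data directory). The path and sentence are hypothetical, and by default the loaded tokenizer appends an EOS id.

# Hypothetical usage; assumes a trained SentencePiece model already exists
# at `vocab_path`.
import tensorflow as tf

from algoperf.workloads.wmt import tokenizer

vocab_path = '/data/wmt/wmt_sentencepiece_model'  # hypothetical path
sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path)

ids = sp_tokenizer.tokenize(tf.constant('ein kleines Beispiel'))
text = sp_tokenizer.detokenize(ids)
print(ids.numpy(), text.numpy())
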
diff --git a/algoperf/workloads/wmt/wmt_jax/decode.py b/algoperf/workloads/wmt/wmt_jax/decode.py
index dfead5918..196d9175e 100644
--- a/algoperf/workloads/wmt/wmt_jax/decode.py
+++ b/algoperf/workloads/wmt/wmt_jax/decode.py
@@ -7,9 +7,9 @@
import flax
import jax
-from jax import lax
import jax.numpy as jnp
import numpy as np
+from jax import lax
# Constants
# We assume the default End-of-Sentence token id is 2 (SentencePiece).
@@ -78,8 +78,9 @@ def gather_beams(nested, beam_indices, batch_size, new_beam_size):
[batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
"""
batch_indices = jnp.reshape(
- jnp.arange(batch_size * new_beam_size) // new_beam_size,
- (batch_size, new_beam_size))
+ jnp.arange(batch_size * new_beam_size) // new_beam_size,
+ (batch_size, new_beam_size),
+ )
def gather_fn(x):
if x.ndim < 2: # ignore scalars (e.g. cache index)
@@ -114,6 +115,7 @@ def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size):
@flax.struct.dataclass
class BeamState:
"""Holds beam search state data."""
+
# The position of the decoding loop in the length dimension.
cur_index: jax.Array # scalar int32: current decoded length index
# The active sequence log probabilities and finished sequence scores.
@@ -133,7 +135,8 @@ def beam_init(batch_size, beam_size, max_decode_len, cache):
"""Initializes the beam search state data structure."""
cur_index0 = jnp.array(0)
live_logprobs0 = jnp.tile(
- jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1])
+ jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1]
+ )
finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF
live_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
@@ -141,25 +144,28 @@ def beam_init(batch_size, beam_size, max_decode_len, cache):
# add beam dimension to attention cache pytree elements
beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache)
return BeamState(
- cur_index=cur_index0,
- live_logprobs=live_logprobs0,
- finished_scores=finished_scores0,
- live_seqs=live_seqs0,
- finished_seqs=finished_seqs0,
- finished_flags=finished_flags0,
- cache=beam_cache0)
+ cur_index=cur_index0,
+ live_logprobs=live_logprobs0,
+ finished_scores=finished_scores0,
+ live_seqs=live_seqs0,
+ finished_seqs=finished_seqs0,
+ finished_flags=finished_flags0,
+ cache=beam_cache0,
+ )
# Beam search routine:
-def beam_search(inputs,
- cache,
- tokens_to_logits,
- beam_size=4,
- alpha=0.6,
- eos_id=EOS_ID,
- max_decode_len=None):
+def beam_search(
+ inputs,
+ cache,
+ tokens_to_logits,
+ beam_size=4,
+ alpha=0.6,
+ eos_id=EOS_ID,
+ max_decode_len=None,
+):
"""Beam search for transformer machine translation.
Args:
@@ -185,10 +191,9 @@ def beam_search(inputs,
end_marker = jnp.array(eos_id)
# initialize beam search state
- beam_search_init_state = beam_init(batch_size,
- beam_size,
- max_decode_len,
- cache)
+ beam_search_init_state = beam_init(
+ batch_size, beam_size, max_decode_len, cache
+ )
def beam_search_loop_cond_fn(state):
"""Beam search loop termination condition."""
@@ -201,11 +206,12 @@ def beam_search_loop_cond_fn(state):
best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty
# Get the worst scores from finished sequences.
worst_finished_scores = jnp.min(
- state.finished_scores, axis=1, keepdims=True)
+ state.finished_scores, axis=1, keepdims=True
+ )
# Mask out scores from slots without any actual finished sequences.
- worst_finished_scores = jnp.where(state.finished_flags,
- worst_finished_scores,
- NEG_INF)
+ worst_finished_scores = jnp.where(
+ state.finished_flags, worst_finished_scores, NEG_INF
+ )
# If no best possible live score is better than current worst finished
# scores, the search cannot improve the finished set further.
search_terminated = jnp.all(worst_finished_scores > best_live_scores)
@@ -221,8 +227,10 @@ def beam_search_loop_body_fn(state):
# dimension for feeding into the model.
# --> [batch * beam, 1]
flat_ids = flatten_beam_dim(
- lax.dynamic_slice(state.live_seqs, (0, 0, state.cur_index),
- (batch_size, beam_size, 1)))
+ lax.dynamic_slice(
+ state.live_seqs, (0, 0, state.cur_index), (batch_size, beam_size, 1)
+ )
+ )
# Flatten beam dimension into batch to be compatible with model.
# {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
flat_cache = jax.tree.map(flatten_beam_dim, state.cache)
@@ -237,14 +245,16 @@ def beam_search_loop_body_fn(state):
# Unflatten beam dimension in attention cache arrays
# {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
new_cache = jax.tree.map(
- lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache)
+ lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache
+ )
# Gather log probabilities from logits
candidate_log_probs = jax.nn.log_softmax(logits)
# Add new logprobs to existing prefix logprobs.
# --> [batch, beam, vocab]
- log_probs = (
- candidate_log_probs + jnp.expand_dims(state.live_logprobs, axis=2))
+ log_probs = candidate_log_probs + jnp.expand_dims(
+ state.live_logprobs, axis=2
+ )
# We'll need the vocab size, gather it from the log probability dimension.
vocab_size = log_probs.shape[2]
@@ -264,10 +274,9 @@ def beam_search_loop_body_fn(state):
topk_beam_indices = topk_indices // vocab_size
# Gather 2*k top beams.
# --> [batch, 2*beams, length]
- topk_seq = gather_beams(state.live_seqs,
- topk_beam_indices,
- batch_size,
- beams_to_keep)
+ topk_seq = gather_beams(
+ state.live_seqs, topk_beam_indices, batch_size, beams_to_keep
+ )
# Append the most probable 2*K token IDs to the top 2*K sequences
# Recover token id by modulo division and expand Id array for broadcasting.
@@ -275,13 +284,14 @@ def beam_search_loop_body_fn(state):
topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
# Update sequences for the 2*K top-k new sequences.
# --> [batch, 2*beams, length]
- topk_seq = lax.dynamic_update_slice(topk_seq,
- topk_ids, (0, 0, state.cur_index + 1))
+ topk_seq = lax.dynamic_update_slice(
+ topk_seq, topk_ids, (0, 0, state.cur_index + 1)
+ )
# Update LIVE (in-progress) sequences:
# Did any of these sequences reach an end marker?
# --> [batch, 2*beams]
- newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
+ newly_finished = topk_seq[:, :, state.cur_index + 1] == end_marker
# To prevent these newly finished sequences from being added to the LIVE
# set of active beam search sequences, set their log probs to a very large
# negative value.
@@ -292,22 +302,20 @@ def beam_search_loop_body_fn(state):
new_topk_indices = jnp.flip(new_topk_indices, axis=1)
# Gather the top k beams (from top 2*k beams).
# --> [batch, beams, length], [batch, beams]
- top_alive_seq, top_alive_log_probs = gather_beams([topk_seq, new_log_probs],
- new_topk_indices,
- batch_size, beam_size)
+ top_alive_seq, top_alive_log_probs = gather_beams(
+ [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size
+ )
# Determine the top k beam indices from the original set of all beams.
# --> [batch, beams]
- top_alive_indices = gather_beams(topk_beam_indices,
- new_topk_indices,
- batch_size,
- beam_size)
+ top_alive_indices = gather_beams(
+ topk_beam_indices, new_topk_indices, batch_size, beam_size
+ )
# With these, gather the top k beam-associated caches.
# --> {[batch, beams, ...], ...}
- top_alive_cache = gather_beams(new_cache,
- top_alive_indices,
- batch_size,
- beam_size)
+ top_alive_cache = gather_beams(
+ new_cache, top_alive_indices, batch_size, beam_size
+ )
# Update FINISHED (reached end of sentence) sequences:
# Calculate new seq scores from log probabilities.
@@ -320,42 +328,54 @@ def beam_search_loop_body_fn(state):
# new finished sequence scores to existing finished scores and select the
# best from the new set of beams.
finished_seqs = jnp.concatenate( # --> [batch, 3*beams, length]
- [state.finished_seqs, topk_seq],
- axis=1)
+ [state.finished_seqs, topk_seq], axis=1
+ )
finished_scores = jnp.concatenate( # --> [batch, 3*beams]
- [state.finished_scores, new_scores], axis=1)
+ [state.finished_scores, new_scores], axis=1
+ )
finished_flags = jnp.concatenate( # --> [batch, 3*beams]
- [state.finished_flags, newly_finished], axis=1)
+ [state.finished_flags, newly_finished], axis=1
+ )
# --> [batch, beams, length], [batch, beams], [batch, beams]
top_finished_seq, top_finished_scores, top_finished_flags = (
- gather_topk_beams([finished_seqs, finished_scores, finished_flags],
- finished_scores, batch_size, beam_size))
+ gather_topk_beams(
+ [finished_seqs, finished_scores, finished_flags],
+ finished_scores,
+ batch_size,
+ beam_size,
+ )
+ )
return BeamState(
- cur_index=state.cur_index + 1,
- live_logprobs=top_alive_log_probs,
- finished_scores=top_finished_scores,
- live_seqs=top_alive_seq,
- finished_seqs=top_finished_seq,
- finished_flags=top_finished_flags,
- cache=top_alive_cache)
+ cur_index=state.cur_index + 1,
+ live_logprobs=top_alive_log_probs,
+ finished_scores=top_finished_scores,
+ live_seqs=top_alive_seq,
+ finished_seqs=top_finished_seq,
+ finished_flags=top_finished_flags,
+ cache=top_alive_cache,
+ )
# Run while loop and get final beam search state.
- final_state = lax.while_loop(beam_search_loop_cond_fn,
- beam_search_loop_body_fn,
- beam_search_init_state)
+ final_state = lax.while_loop(
+ beam_search_loop_cond_fn, beam_search_loop_body_fn, beam_search_init_state
+ )
# Account for the edge-case where there are no finished sequences for a
# particular batch item. If so, return live sequences for that batch item.
# --> [batch]
none_finished = jnp.any(final_state.finished_flags, axis=1)
# --> [batch, beams, length]
- finished_seqs = jnp.where(none_finished[:, None, None],
- final_state.finished_seqs,
- final_state.live_seqs)
+ finished_seqs = jnp.where(
+ none_finished[:, None, None],
+ final_state.finished_seqs,
+ final_state.live_seqs,
+ )
# --> [batch, beams]
- finished_scores = jnp.where(none_finished[:, None],
- final_state.finished_scores,
- final_state.live_logprobs)
+ finished_scores = jnp.where(
+ none_finished[:, None],
+ final_state.finished_scores,
+ final_state.live_logprobs,
+ )
return finished_seqs, finished_scores
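
Much of the reformatted decode.py code above is beam bookkeeping: reshaping between `[batch, beam, ...]` and `[batch * beam, ...]` and gathering rows by per-batch beam indices, as in `gather_beams`. Below is a small NumPy stand-in for that gather pattern, with made-up shapes; the real code does the equivalent with `jnp`.

# Sketch of the beam-gather pattern: for every batch element, pick the rows
# of a [batch, old_beam, ...] array named by per-batch beam indices.
import numpy as np

batch, old_beam, new_beam, length = 2, 4, 2, 3
seqs = np.arange(batch * old_beam * length).reshape(batch, old_beam, length)
beam_indices = np.array([[3, 1],   # keep beams 3 and 1 of batch item 0
                         [0, 2]])  # keep beams 0 and 2 of batch item 1

# Same batch-index construction as gather_beams above.
batch_indices = np.reshape(
  np.arange(batch * new_beam) // new_beam, (batch, new_beam)
)
gathered = seqs[batch_indices, beam_indices]  # --> [batch, new_beam, length]
print(gathered.shape)  # (2, 2, 3)
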
diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py
index 97fee032f..81f2ece4c 100644
--- a/algoperf/workloads/wmt/wmt_jax/models.py
+++ b/algoperf/workloads/wmt/wmt_jax/models.py
@@ -5,16 +5,21 @@
from typing import Any, Callable, Optional
+import jax.numpy as jnp
+import numpy as np
from flax import linen as nn
from flax import struct
from jax import lax
-import jax.numpy as jnp
-import numpy as np
+
+from algoperf.jax_utils import Dropout
+
+DROPOUT_RATE = 0.1
@struct.dataclass
class TransformerConfig:
"""Global hyperparameters used to minimize obnoxious kwarg plumbing."""
+
share_embeddings: bool = True
dtype: Any = jnp.float32
vocab_size: int = 32000
@@ -26,10 +31,6 @@ class TransformerConfig:
max_len: int = 256
activation: Callable = nn.relu
glu: bool = False
- #If None, defaults to 0.1.
- dropout_rate: Optional[float] = 0.1
- #If None, defaults to 0.1.
- attention_dropout_rate: Optional[float] = 0.1
attention_temp: float = 1.0
deterministic: bool = False
decode: bool = False
@@ -44,7 +45,8 @@ def shift_right(x, axis=1):
pad_widths = [(0, 0)] * len(x.shape)
pad_widths[axis] = (1, 0)
padded = jnp.pad(
- x, pad_widths, mode='constant', constant_values=x.dtype.type(0))
+ x, pad_widths, mode='constant', constant_values=x.dtype.type(0)
+ )
return padded[:, :-1]
@@ -68,8 +70,8 @@ def init(key, shape, dtype=np.float32):
position = np.arange(0, max_len)[:, np.newaxis]
scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1)
div_term = min_scale * np.exp(np.arange(0, d_feature // 2) * scale_factor)
- pe[:, :d_feature // 2] = np.sin(position * div_term)
- pe[:, d_feature // 2:2 * (d_feature // 2)] = np.cos(position * div_term)
+ pe[:, : d_feature // 2] = np.sin(position * div_term)
+ pe[:, d_feature // 2 : 2 * (d_feature // 2)] = np.cos(position * div_term)
pe = pe[np.newaxis, :, :] # [1, max_len, d_feature]
return jnp.array(pe)
@@ -83,6 +85,7 @@ class AddPositionEmbs(nn.Module):
config: TransformerConfig dataclass containing hyperparameters.
decode: whether to run in single-position autoregressive mode.
"""
+
config: TransformerConfig
decode: bool = False
@@ -103,27 +106,28 @@ def __call__(self, inputs, inputs_positions=None):
"""
cfg = self.config
# inputs.shape is (batch_size, seq_len, emb_dim)
- assert inputs.ndim == 3, ('Number of dimensions should be 3,'
- f' but it is: {inputs.ndim}')
+ assert inputs.ndim == 3, (
+ f'Number of dimensions should be 3, but it is: {inputs.ndim}'
+ )
length = inputs.shape[1]
pos_emb_shape = (1, cfg.max_len, inputs.shape[-1])
if cfg.posemb_init is None:
# Use a fixed (non-learned) sinusoidal position embedding.
- pos_embedding = sinusoidal_init(max_len=cfg.max_len)(None,
- pos_emb_shape,
- None)
+ pos_embedding = sinusoidal_init(max_len=cfg.max_len)(
+ None, pos_emb_shape, None
+ )
else:
- pos_embedding = self.param('pos_embedding',
- cfg.posemb_init,
- pos_emb_shape)
+ pos_embedding = self.param(
+ 'pos_embedding', cfg.posemb_init, pos_emb_shape
+ )
pe = pos_embedding[:, :length, :]
# We use a cache position index for tracking decoding position.
if self.decode:
is_initialized = self.has_variable('cache', 'cache_index')
- cache_index = self.variable('cache',
- 'cache_index',
- lambda: jnp.array(0, dtype=jnp.uint32))
+ cache_index = self.variable(
+ 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.uint32)
+ )
if is_initialized:
i = cache_index.value
cache_index.value = i + 1
@@ -144,43 +148,43 @@ class MlpBlock(nn.Module):
config: TransformerConfig dataclass containing hyperparameters.
out_dim: optionally specify out dimension.
"""
+
config: TransformerConfig
out_dim: Optional[int] = None
@nn.compact
- def __call__(self, inputs):
+ def __call__(self, inputs, dropout_rate=DROPOUT_RATE):
"""Applies Transformer MlpBlock module."""
cfg = self.config
- actual_out_dim = (
- inputs.shape[-1] if self.out_dim is None else self.out_dim)
+
+ actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
x = nn.Dense(
- cfg.mlp_dim,
- dtype=cfg.dtype,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init)(
- inputs)
+ cfg.mlp_dim,
+ dtype=cfg.dtype,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ )(inputs)
x = cfg.activation(x)
if cfg.glu:
y = nn.Dense(
- cfg.mlp_dim,
- dtype=cfg.dtype,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init)(
- inputs)
- x = x * y
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
- output = nn.Dense(
- actual_out_dim,
+ cfg.mlp_dim,
dtype=cfg.dtype,
kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init)(
- x)
- output = nn.Dropout(rate=dropout_rate)(
- output, deterministic=cfg.deterministic)
+ bias_init=cfg.bias_init,
+ )(inputs)
+ x = x * y
+ x = Dropout(rate=dropout_rate)(
+ x, rate=dropout_rate, deterministic=cfg.deterministic
+ )
+ output = nn.Dense(
+ actual_out_dim,
+ dtype=cfg.dtype,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ )(x)
+ output = Dropout(rate=dropout_rate)(
+ output, rate=dropout_rate, deterministic=cfg.deterministic
+ )
return output
@@ -190,10 +194,11 @@ class Encoder1DBlock(nn.Module):
Attributes:
config: TransformerConfig dataclass containing hyperparameters.
"""
+
config: TransformerConfig
@nn.compact
- def __call__(self, inputs, encoder_mask=None):
+ def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE):
"""Applies Encoder1DBlock module.
Args:
@@ -204,39 +209,34 @@ def __call__(self, inputs, encoder_mask=None):
output after transformer encoder block.
"""
cfg = self.config
+
pre_ln = cfg.pre_ln
# Attention block.
assert inputs.ndim == 3
x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs
- if cfg.attention_dropout_rate is None:
- attention_dropout_rate = 0.1
- else:
- attention_dropout_rate = cfg.attention_dropout_rate
x = nn.MultiHeadDotProductAttention(
- num_heads=cfg.num_heads,
- dtype=cfg.dtype,
- qkv_features=cfg.qkv_dim,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- use_bias=False,
- broadcast_dropout=False,
- dropout_rate=attention_dropout_rate,
- deterministic=cfg.deterministic)(
- cfg.attention_temp * x, x, mask=encoder_mask)
-
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
+ num_heads=cfg.num_heads,
+ dtype=cfg.dtype,
+ qkv_features=cfg.qkv_dim,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ use_bias=False,
+ broadcast_dropout=False,
+ dropout_rate=dropout_rate,
+ deterministic=cfg.deterministic,
+ )(cfg.attention_temp * x, x, mask=encoder_mask)
+
+ x = Dropout(rate=dropout_rate)(
+ x, deterministic=cfg.deterministic, rate=dropout_rate
+ )
x = x + inputs
if not pre_ln:
x = nn.LayerNorm(dtype=cfg.dtype)(x)
# MLP block.
y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
- y = MlpBlock(config=cfg)(y)
+ y = MlpBlock(config=cfg)(y, dropout_rate=dropout_rate)
return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y)
@@ -247,14 +247,18 @@ class EncoderDecoder1DBlock(nn.Module):
Attributes:
config: TransformerConfig dataclass containing hyperparameters.
"""
+
config: TransformerConfig
@nn.compact
- def __call__(self,
- targets,
- encoded,
- decoder_mask=None,
- encoder_decoder_mask=None):
+ def __call__(
+ self,
+ targets,
+ encoded,
+ decoder_mask=None,
+ encoder_decoder_mask=None,
+ dropout_rate=DROPOUT_RATE,
+ ):
"""Applies EncoderDecoder1DBlock module.
Args:
@@ -267,33 +271,29 @@ def __call__(self,
output after transformer encoder-decoder block.
"""
cfg = self.config
+
pre_ln = cfg.pre_ln
# Decoder block.
assert targets.ndim == 3
x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets
- if cfg.attention_dropout_rate is None:
- attention_dropout_rate = 0.1
- else:
- attention_dropout_rate = cfg.attention_dropout_rate
x = nn.MultiHeadDotProductAttention(
- num_heads=cfg.num_heads,
- dtype=cfg.dtype,
- qkv_features=cfg.qkv_dim,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- use_bias=False,
- broadcast_dropout=False,
- dropout_rate=attention_dropout_rate,
- deterministic=cfg.deterministic,
- decode=cfg.decode)(
- cfg.attention_temp * x, x, mask=decoder_mask)
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
+ num_heads=cfg.num_heads,
+ dtype=cfg.dtype,
+ qkv_features=cfg.qkv_dim,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ use_bias=False,
+ broadcast_dropout=False,
+ dropout_rate=dropout_rate,
+ deterministic=cfg.deterministic,
+ decode=cfg.decode,
+ )(cfg.attention_temp * x, x, mask=decoder_mask)
+
+ x = Dropout(rate=dropout_rate)(
+ x, deterministic=cfg.deterministic, rate=dropout_rate
+ )
x = x + targets
if not pre_ln:
x = nn.LayerNorm(dtype=cfg.dtype)(x)
@@ -301,25 +301,27 @@ def __call__(self,
# Encoder-Decoder block.
y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
y = nn.MultiHeadDotProductAttention(
- num_heads=cfg.num_heads,
- dtype=cfg.dtype,
- qkv_features=cfg.qkv_dim,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- use_bias=False,
- broadcast_dropout=False,
- dropout_rate=attention_dropout_rate,
- deterministic=cfg.deterministic)(
- cfg.attention_temp * y, encoded, mask=encoder_decoder_mask)
-
- y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic)
+ num_heads=cfg.num_heads,
+ dtype=cfg.dtype,
+ qkv_features=cfg.qkv_dim,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ use_bias=False,
+ broadcast_dropout=False,
+ dropout_rate=dropout_rate,
+ deterministic=cfg.deterministic,
+ )(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask)
+
+ y = Dropout(rate=dropout_rate)(
+ y, deterministic=cfg.deterministic, rate=dropout_rate
+ )
y = y + x
if not pre_ln:
y = nn.LayerNorm(dtype=cfg.dtype)(y)
# MLP block.
z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y
- z = MlpBlock(config=cfg)(z)
+ z = MlpBlock(config=cfg)(z, dropout_rate=dropout_rate)
return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z)
@@ -331,11 +333,18 @@ class Encoder(nn.Module):
config: TransformerConfig dataclass containing hyperparameters.
shared_embedding: a shared embedding layer to use.
"""
+
config: TransformerConfig
shared_embedding: Any = None
@nn.compact
- def __call__(self, inputs, inputs_positions=None, encoder_mask=None):
+ def __call__(
+ self,
+ inputs,
+ inputs_positions=None,
+ encoder_mask=None,
+ dropout_rate=DROPOUT_RATE,
+ ):
"""Applies Transformer model on the inputs.
Args:
@@ -347,37 +356,40 @@ def __call__(self, inputs, inputs_positions=None, encoder_mask=None):
output of a transformer encoder.
"""
cfg = self.config
+
assert inputs.ndim == 2 # (batch, len)
# Input Embedding
if self.shared_embedding is None:
input_embed = nn.Embed(
- num_embeddings=cfg.vocab_size,
- features=cfg.emb_dim,
- embedding_init=nn.initializers.normal(stddev=1.0))
+ num_embeddings=cfg.vocab_size,
+ features=cfg.emb_dim,
+ embedding_init=nn.initializers.normal(stddev=1.0),
+ )
else:
input_embed = self.shared_embedding
x = inputs.astype('int32')
x = input_embed(x)
- x = AddPositionEmbs(
- config=cfg, decode=False, name='posembed_input')(
- x, inputs_positions=inputs_positions)
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
+ x = AddPositionEmbs(config=cfg, decode=False, name='posembed_input')(
+ x, inputs_positions=inputs_positions
+ )
+ x = Dropout(rate=dropout_rate)(
+ x, deterministic=cfg.deterministic, rate=dropout_rate
+ )
x = x.astype(cfg.dtype)
# Input Encoder
for lyr in range(cfg.num_layers):
- x = Encoder1DBlock(
- config=cfg, name=f'encoderblock_{lyr}')(x, encoder_mask)
+ x = Encoder1DBlock(config=cfg, name=f'encoderblock_{lyr}')(
+ x, encoder_mask, dropout_rate
+ )
encoded = (
- nn.LayerNorm(dtype=cfg.dtype, name='encoder_layernorm')(x)
- if cfg.pre_ln else x)
+ nn.LayerNorm(dtype=cfg.dtype, name='encoder_layernorm')(x)
+ if cfg.pre_ln
+ else x
+ )
return encoded
@@ -389,16 +401,20 @@ class Decoder(nn.Module):
config: TransformerConfig dataclass containing hyperparameters.
shared_embedding: a shared embedding layer to use.
"""
+
config: TransformerConfig
shared_embedding: Any = None
@nn.compact
- def __call__(self,
- encoded,
- targets,
- targets_positions=None,
- decoder_mask=None,
- encoder_decoder_mask=None):
+ def __call__(
+ self,
+ encoded,
+ targets,
+ targets_positions=None,
+ decoder_mask=None,
+ encoder_decoder_mask=None,
+ dropout_rate=DROPOUT_RATE,
+ ):
"""Applies Transformer model on the inputs.
Args:
@@ -419,9 +435,10 @@ def __call__(self,
# Target Embedding
if self.shared_embedding is None:
output_embed = nn.Embed(
- num_embeddings=cfg.vocab_size,
- features=cfg.emb_dim,
- embedding_init=nn.initializers.normal(stddev=1.0))
+ num_embeddings=cfg.vocab_size,
+ features=cfg.emb_dim,
+ embedding_init=nn.initializers.normal(stddev=1.0),
+ )
else:
output_embed = self.shared_embedding
@@ -429,28 +446,29 @@ def __call__(self,
if not cfg.decode:
y = shift_right(y)
y = output_embed(y)
- y = AddPositionEmbs(
- config=cfg, decode=cfg.decode, name='posembed_output')(
- y, inputs_positions=targets_positions)
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic)
+ y = AddPositionEmbs(config=cfg, decode=cfg.decode, name='posembed_output')(
+ y, inputs_positions=targets_positions
+ )
+ y = Dropout(rate=dropout_rate)(
+ y, deterministic=cfg.deterministic, rate=dropout_rate
+ )
y = y.astype(cfg.dtype)
# Target-Input Decoder
for lyr in range(cfg.num_layers):
- y = EncoderDecoder1DBlock(
- config=cfg, name=f'encoderdecoderblock_{lyr}')(
- y,
- encoded,
- decoder_mask=decoder_mask,
- encoder_decoder_mask=encoder_decoder_mask)
+ y = EncoderDecoder1DBlock(config=cfg, name=f'encoderdecoderblock_{lyr}')(
+ y,
+ encoded,
+ decoder_mask=decoder_mask,
+ encoder_decoder_mask=encoder_decoder_mask,
+ dropout_rate=dropout_rate,
+ )
y = (
- nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_layernorm')(y)
- if cfg.pre_ln else y)
+ nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_layernorm')(y)
+ if cfg.pre_ln
+ else y
+ )
# Use the transpose of embedding matrix for logit transform.
logits = output_embed.attend(y.astype(jnp.float32))
@@ -465,6 +483,7 @@ class Transformer(nn.Module):
Attributes:
config: TransformerConfig dataclass containing hyperparameters.
"""
+
config: TransformerConfig
def setup(self):
@@ -473,18 +492,26 @@ def setup(self):
if cfg.share_embeddings:
if cfg.vocab_size is not None:
assert cfg.vocab_size == cfg.vocab_size, (
- "can't share embedding with different vocab sizes.")
+ "can't share embedding with different vocab sizes."
+ )
self.shared_embedding = nn.Embed(
- num_embeddings=cfg.vocab_size,
- features=cfg.emb_dim,
- embedding_init=nn.initializers.normal(stddev=1.0))
+ num_embeddings=cfg.vocab_size,
+ features=cfg.emb_dim,
+ embedding_init=nn.initializers.normal(stddev=1.0),
+ )
else:
self.shared_embedding = None
self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding)
self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding)
- def encode(self, inputs, inputs_positions=None, inputs_segmentation=None):
+ def encode(
+ self,
+ inputs,
+ inputs_positions=None,
+ inputs_segmentation=None,
+ dropout_rate=DROPOUT_RATE,
+ ):
"""Applies Transformer encoder-branch on the inputs.
Args:
@@ -498,27 +525,33 @@ def encode(self, inputs, inputs_positions=None, inputs_segmentation=None):
cfg = self.config
# Make padding attention mask.
encoder_mask = nn.make_attention_mask(
- inputs > 0, inputs > 0, dtype=cfg.dtype)
+ inputs > 0, inputs > 0, dtype=cfg.dtype
+ )
# Add segmentation block-diagonal attention mask if using segmented data.
if inputs_segmentation is not None:
encoder_mask = nn.combine_masks(
- encoder_mask,
- nn.make_attention_mask(
- inputs_segmentation,
- inputs_segmentation,
- jnp.equal,
- dtype=cfg.dtype))
+ encoder_mask,
+ nn.make_attention_mask(
+ inputs_segmentation, inputs_segmentation, jnp.equal, dtype=cfg.dtype
+ ),
+ )
return self.encoder(
- inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask)
+ inputs,
+ inputs_positions=inputs_positions,
+ encoder_mask=encoder_mask,
+ dropout_rate=dropout_rate,
+ )
def decode(
- self,
- encoded,
- inputs, # only needed for masks
- targets,
- targets_positions=None,
- inputs_segmentation=None,
- targets_segmentation=None):
+ self,
+ encoded,
+ inputs, # only needed for masks
+ targets,
+ targets_positions=None,
+ inputs_segmentation=None,
+ targets_segmentation=None,
+ dropout_rate=DROPOUT_RATE,
+ ):
"""Applies Transformer decoder-branch on encoded-input and target.
Args:
@@ -538,45 +571,51 @@ def decode(
if cfg.decode:
decoder_mask = None
encoder_decoder_mask = nn.make_attention_mask(
- jnp.ones_like(targets) > 0, inputs > 0, dtype=cfg.dtype)
+ jnp.ones_like(targets) > 0, inputs > 0, dtype=cfg.dtype
+ )
else:
decoder_mask = nn.combine_masks(
- nn.make_attention_mask(targets > 0, targets > 0, dtype=cfg.dtype),
- nn.make_causal_mask(targets, dtype=cfg.dtype))
+ nn.make_attention_mask(targets > 0, targets > 0, dtype=cfg.dtype),
+ nn.make_causal_mask(targets, dtype=cfg.dtype),
+ )
encoder_decoder_mask = nn.make_attention_mask(
- targets > 0, inputs > 0, dtype=cfg.dtype)
+ targets > 0, inputs > 0, dtype=cfg.dtype
+ )
# Add segmentation block-diagonal attention masks if using segmented data.
if inputs_segmentation is not None:
decoder_mask = nn.combine_masks(
- decoder_mask,
- nn.make_attention_mask(
- targets_segmentation,
- targets_segmentation,
- jnp.equal,
- dtype=cfg.dtype))
+ decoder_mask,
+ nn.make_attention_mask(
+ targets_segmentation, targets_segmentation, jnp.equal, dtype=cfg.dtype
+ ),
+ )
encoder_decoder_mask = nn.combine_masks(
- encoder_decoder_mask,
- nn.make_attention_mask(
- targets_segmentation,
- inputs_segmentation,
- jnp.equal,
- dtype=cfg.dtype))
+ encoder_decoder_mask,
+ nn.make_attention_mask(
+ targets_segmentation, inputs_segmentation, jnp.equal, dtype=cfg.dtype
+ ),
+ )
logits = self.decoder(
- encoded,
- targets,
- targets_positions=targets_positions,
- decoder_mask=decoder_mask,
- encoder_decoder_mask=encoder_decoder_mask)
+ encoded,
+ targets,
+ targets_positions=targets_positions,
+ decoder_mask=decoder_mask,
+ encoder_decoder_mask=encoder_decoder_mask,
+ dropout_rate=dropout_rate,
+ )
return logits.astype(self.config.dtype)
- def __call__(self,
- inputs,
- targets,
- inputs_positions=None,
- targets_positions=None,
- inputs_segmentation=None,
- targets_segmentation=None):
+ def __call__(
+ self,
+ inputs,
+ targets,
+ inputs_positions=None,
+ targets_positions=None,
+ inputs_segmentation=None,
+ targets_segmentation=None,
+ dropout_rate=DROPOUT_RATE,
+ ):
"""Applies Transformer model on the inputs.
Args:
@@ -591,14 +630,18 @@ def __call__(self,
logits array from full transformer.
"""
encoded = self.encode(
- inputs,
- inputs_positions=inputs_positions,
- inputs_segmentation=inputs_segmentation)
+ inputs,
+ inputs_positions=inputs_positions,
+ inputs_segmentation=inputs_segmentation,
+ dropout_rate=dropout_rate,
+ )
return self.decode(
- encoded,
- inputs, # only used for masks
- targets,
- targets_positions=targets_positions,
- inputs_segmentation=inputs_segmentation,
- targets_segmentation=targets_segmentation)
+ encoded,
+ inputs, # only used for masks
+ targets,
+ targets_positions=targets_positions,
+ inputs_segmentation=inputs_segmentation,
+ targets_segmentation=targets_segmentation,
+ dropout_rate=dropout_rate,
+ )
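
The models.py changes above replace the `dropout_rate` / `attention_dropout_rate` config fields with a call-time `dropout_rate` argument (defaulting to the module-level `DROPOUT_RATE = 0.1`) that is threaded through every block via the custom `Dropout` from `algoperf.jax_utils`. A hypothetical call site is sketched below, assuming that `Dropout` draws its randomness from the usual Flax `'dropout'` RNG collection; the toy shapes and the 0.2 rate are made up.

# Hypothetical caller passing the new call-time dropout_rate argument.
import jax
import jax.numpy as jnp

from algoperf.workloads.wmt.wmt_jax import models

config = models.TransformerConfig()  # deterministic=False by default
model = models.Transformer(config)

inputs = jnp.ones((2, 16), jnp.float32)   # (batch, length), toy shapes
targets = jnp.ones((2, 16), jnp.float32)

params_rng, dropout_rng = jax.random.split(jax.random.PRNGKey(0))
variables = model.init(
  {'params': params_rng, 'dropout': dropout_rng}, inputs, targets
)
logits = model.apply(
  variables,
  inputs,
  targets,
  dropout_rate=0.2,  # overrides the module-level default of 0.1
  rngs={'dropout': dropout_rng},
)
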
diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py
index cdfcb91df..51d8a85a7 100644
--- a/algoperf/workloads/wmt/wmt_jax/workload.py
+++ b/algoperf/workloads/wmt/wmt_jax/workload.py
@@ -1,23 +1,21 @@
"""WMT workload implemented in Jax."""
-from dataclasses import replace
import functools
+from dataclasses import replace
from typing import Any, Dict, Iterator, Optional, Tuple
-from absl import logging
-from flax import jax_utils
-from flax import linen as nn
-from flax.training import common_utils
import jax
import jax.numpy as jnp
import numpy as np
import optax
+from absl import logging
+from flax import jax_utils
+from flax import linen as nn
+from flax.training import common_utils
-from algoperf import param_utils
-from algoperf import spec
+from algoperf import param_utils, spec
from algoperf.workloads.wmt import bleu
-from algoperf.workloads.wmt.wmt_jax import decode
-from algoperf.workloads.wmt.wmt_jax import models
+from algoperf.workloads.wmt.wmt_jax import decode, models
from algoperf.workloads.wmt.workload import BaseWmtWorkload
@@ -31,11 +29,12 @@ class WmtWorkload(BaseWmtWorkload):
"""WMT Jax workload."""
def compute_weighted_cross_entropy(
- self,
- logits: spec.Tensor,
- targets: spec.Tensor,
- weights: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.1) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ logits: spec.Tensor,
+ targets: spec.Tensor,
+ weights: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.1,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Compute weighted cross entropy and entropy for log probs and targets.
Args:
@@ -50,76 +49,86 @@ def compute_weighted_cross_entropy(
valid examples in batch, 'per_example': 1-d array of per-example losses}
"""
if logits.ndim != targets.ndim + 1:
- raise ValueError(f'Incorrect shapes. Got shape {logits.shape} logits and '
- f'{targets.shape} targets.')
+ raise ValueError(
+ f'Incorrect shapes. Got shape {logits.shape} logits and '
+ f'{targets.shape} targets.'
+ )
smoothed_targets = optax.smooth_labels(
- common_utils.onehot(targets, self._vocab_size), label_smoothing)
+ common_utils.onehot(targets, self._vocab_size), label_smoothing
+ )
per_example_losses = -jnp.sum(
- smoothed_targets * nn.log_softmax(logits), axis=-1)
+ smoothed_targets * nn.log_softmax(logits), axis=-1
+ )
if weights is None:
weights = jnp.ones_like(targets)
- per_example_losses = jnp.where(weights, per_example_losses, 0.)
+ per_example_losses = jnp.where(weights, per_example_losses, 0.0)
summed_loss = per_example_losses.sum()
n_valid_examples = weights.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': n_valid_examples,
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': n_valid_examples,
+ 'per_example': per_example_losses,
}
@functools.partial(
- jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,))
+ jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)
+ )
def eval_step_pmapped(
- self, params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]:
+ self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]
+ ) -> Dict[str, spec.Tensor]:
"""Calculate evaluation metrics on a batch."""
inputs = batch['inputs']
targets = batch['targets']
weights = batch['weights']
logits = self._eval_model.apply({'params': params}, inputs, targets)
- summed_loss = self.compute_weighted_cross_entropy(logits,
- targets,
- weights,
- 0.0)['summed']
+ summed_loss = self.compute_weighted_cross_entropy(
+ logits, targets, weights, 0.0
+ )['summed']
acc_sum, weight_sum = self.compute_weighted_accuracy(
- logits, targets, weights)
+ logits, targets, weights
+ )
return {
- 'loss': summed_loss,
- 'accuracy': acc_sum,
- 'denominator': weight_sum,
+ 'loss': summed_loss,
+ 'accuracy': acc_sum,
+ 'denominator': weight_sum,
}
- def eval_step(self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]:
+ def eval_step(
+ self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]
+ ) -> Dict[str, spec.Tensor]:
replicated_eval_metrics = self.eval_step_pmapped(params, batch)
return jax.tree.map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics)
@functools.partial(
- jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,))
- def initialize_cache(self,
- inputs: spec.Tensor,
- max_decode_len: int = 256) -> Dict[str, spec.Tensor]:
+ jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)
+ )
+ def initialize_cache(
+ self, inputs: spec.Tensor, max_decode_len: int = 256
+ ) -> Dict[str, spec.Tensor]:
"""Initialize a cache for a given input shape and max decode length."""
config = models.TransformerConfig(deterministic=True, decode=True)
target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:]
initial_variables = models.Transformer(config).init(
- jax.random.PRNGKey(0),
- jnp.ones(inputs.shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32))
+ jax.random.PRNGKey(0),
+ jnp.ones(inputs.shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32),
+ )
return initial_variables['cache']
# eos_id, max_decode_len are constant.
@functools.partial(
- jax.pmap, axis_name='batch', static_broadcasted_argnums=(0, 4, 5))
- def predict_step(self,
- inputs: spec.Tensor,
- params: spec.ParameterContainer,
- cache: Dict[str, spec.Tensor],
- eos_id: int,
- max_decode_len: int,
- beam_size: int = 4) -> spec.Tensor:
+ jax.pmap, axis_name='batch', static_broadcasted_argnums=(0, 4, 5)
+ )
+ def predict_step(
+ self,
+ inputs: spec.Tensor,
+ params: spec.ParameterContainer,
+ cache: Dict[str, spec.Tensor],
+ eos_id: int,
+ max_decode_len: int,
+ beam_size: int = 4,
+ ) -> spec.Tensor:
"""Predict translation with fast decoding beam search on a batch."""
config = replace(self._eval_model.config, decode=True)
# Prepare transformer fast-decoder call for beam search: for beam search, we
@@ -129,27 +138,29 @@ def predict_step(self,
# i.e. if we denote each batch element subtensor as el[n]:
# [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2]
encoded_inputs = decode.flat_batch_beam_expand(
- models.Transformer(config).apply({'params': params},
- inputs,
- method=models.Transformer.encode),
- beam_size)
+ models.Transformer(config).apply(
+ {'params': params}, inputs, method=models.Transformer.encode
+ ),
+ beam_size,
+ )
raw_inputs = decode.flat_batch_beam_expand(inputs, beam_size)
def tokens_ids_to_logits(
- flat_ids: spec.Tensor, flat_cache: Dict[str, spec.Tensor]
+ flat_ids: spec.Tensor, flat_cache: Dict[str, spec.Tensor]
) -> Tuple[spec.Tensor, Dict[str, spec.Tensor]]:
"""Token slice to logits from decoder model."""
# --> [batch * beam, 1, vocab]
flat_logits, new_vars = models.Transformer(config).apply(
- {
- 'params': params,
- 'cache': flat_cache,
- },
- encoded_inputs,
- raw_inputs, # only needed for input padding mask
- flat_ids,
- mutable=['cache'],
- method=models.Transformer.decode)
+ {
+ 'params': params,
+ 'cache': flat_cache,
+ },
+ encoded_inputs,
+ raw_inputs, # only needed for input padding mask
+ flat_ids,
+ mutable=['cache'],
+ method=models.Transformer.decode,
+ )
new_flat_cache = new_vars['cache']
# Remove singleton sequence-length dimension:
# [batch * beam, 1, vocab] --> [batch * beam, vocab]
@@ -159,35 +170,36 @@ def tokens_ids_to_logits(
# Using the above-defined single-step decoder function, run a
# beam search over possible sequences given input encoding.
beam_seqs, _ = decode.beam_search(
- inputs,
- cache,
- tokens_ids_to_logits,
- beam_size=beam_size,
- alpha=0.6,
- eos_id=eos_id,
- max_decode_len=max_decode_len)
+ inputs,
+ cache,
+ tokens_ids_to_logits,
+ beam_size=beam_size,
+ alpha=0.6,
+ eos_id=eos_id,
+ max_decode_len=max_decode_len,
+ )
# Beam search returns [n_batch, n_beam, n_length + 1] with beam dimension
# sorted in increasing order of log-probability.
# Return the highest scoring beam sequence, drop first dummy 0 token.
return beam_seqs[:, -1, 1:]
- def translate_and_calculate_bleu(self,
- params: spec.ParameterContainer,
- ds_iter: Iterator,
- num_batches: int,
- max_predict_length: int) -> spec.Tensor:
+ def translate_and_calculate_bleu(
+ self,
+ params: spec.ParameterContainer,
+ ds_iter: Iterator,
+ num_batches: int,
+ max_predict_length: int,
+ ) -> spec.Tensor:
"""Translates the `predict_ds` and calculates the BLEU score."""
logging.info('Translating evaluation dataset.')
references, predictions = [], []
for _ in range(num_batches):
pred_batch = next(ds_iter)
cache = self.initialize_cache(pred_batch['inputs'])
- predicted = self.predict_step(pred_batch['inputs'],
- params,
- cache,
- decode.EOS_ID,
- max_predict_length)
+ predicted = self.predict_step(
+ pred_batch['inputs'], params, cache, decode.EOS_ID, max_predict_length
+ )
predicted = _to_host(predicted)
targets = _to_host(pred_batch['targets'])
# Find actual batch size, ignoring the potential padding.
@@ -206,13 +218,7 @@ def translate_and_calculate_bleu(self,
bleu_score = bleu.corpus_bleu(predictions, [references]).score
return bleu_score
- def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """aux_dropout_rate is used as attention_dropout_rate."""
-
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
init_fake_batch_size = 2
input_shape = (init_fake_batch_size, 256)
target_shape = (init_fake_batch_size, 256)
@@ -225,20 +231,20 @@ def init_model_fn(
raise ValueError(f'Unknown activation function {self.activation}.')
model_config = models.TransformerConfig(
- dropout_rate=dropout_rate,
- attention_dropout_rate=aux_dropout_rate,
- pre_ln=self.pre_ln,
- attention_temp=self.attention_temp,
- activation=activation,
- glu=self.glu)
+ pre_ln=self.pre_ln,
+ attention_temp=self.attention_temp,
+ activation=activation,
+ glu=self.glu,
+ )
self._train_model = models.Transformer(model_config)
eval_config = replace(model_config, deterministic=True)
self._eval_model = models.Transformer(eval_config)
- params_rng, dropout_rng = jax.random.split(rng)
- initial_variables = jax.jit(
- self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng},
- jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32))
+ params_rng, _ = jax.random.split(rng)
+ initial_variables = jax.jit(self._eval_model.init)(
+ {'params': params_rng},
+ jnp.ones(input_shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32),
+ )
initial_params = initial_variables['params']
self._param_shapes = param_utils.jax_param_shapes(initial_params)
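Note (not part of the patch): a minimal sketch of the init pattern used above. Because the deterministic eval model is used for initialization, only the 'params' RNG stream is required and the 'dropout' stream can be dropped. The `Tiny` module below is illustrative only, not from this repository.

import jax
import jax.numpy as jnp
import flax.linen as nn


class Tiny(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dense(4)(x)


params_rng, _ = jax.random.split(jax.random.PRNGKey(0))
# Only the 'params' RNG is passed; no 'dropout' RNG is needed for init.
variables = jax.jit(Tiny().init)({'params': params_rng}, jnp.ones((2, 3)))
params = variables['params']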
@@ -249,45 +255,54 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'shared_embedding'
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
inputs = augmented_and_preprocessed_input_batch.get('inputs', None)
targets = augmented_and_preprocessed_input_batch.get('targets', None)
inputs_positions = augmented_and_preprocessed_input_batch.get(
- 'inputs_position', None)
+ 'inputs_position', None
+ )
targets_positions = augmented_and_preprocessed_input_batch.get(
- 'targets_position', None)
+ 'targets_position', None
+ )
inputs_segmentations = augmented_and_preprocessed_input_batch.get(
- 'inputs_segmentation', None)
+ 'inputs_segmentation', None
+ )
targets_segmentations = augmented_and_preprocessed_input_batch.get(
- 'targets_segmentation', None)
+ 'targets_segmentation', None
+ )
if mode == spec.ForwardPassMode.TRAIN:
model = self._train_model
else:
model = self._eval_model
- logits_batch = model.apply({'params': params},
- inputs,
- targets,
- inputs_positions=inputs_positions,
- targets_positions=targets_positions,
- inputs_segmentation=inputs_segmentations,
- targets_segmentation=targets_segmentations,
- rngs={'dropout': rng})
+ logits_batch = model.apply(
+ {'params': params},
+ inputs,
+ targets,
+ inputs_positions=inputs_positions,
+ targets_positions=targets_positions,
+ inputs_segmentation=inputs_segmentations,
+ targets_segmentation=targets_segmentations,
+ rngs={'dropout': rng},
+ dropout_rate=dropout_rate,
+ )
return logits_batch, None
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
del num_examples
eval_denominator = total_metrics.pop('denominator')
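Note (not part of the patch): `initialize_cache` and `predict_step` above are wrapped with `jax.pmap(..., static_broadcasted_argnums=...)`. A minimal sketch of that mechanism, with illustrative names only: arguments marked static are treated as compile-time constants and broadcast to every device instead of being sharded along the leading axis.

import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.pmap, axis_name='batch', static_broadcasted_argnums=(1,))
def scale(x, factor):
  # `factor` is a static Python int; `x` carries the leading device dimension.
  return x * factor


n = jax.local_device_count()
out = scale(jnp.ones((n, 4)), 3)  # shape (n, 4); each device scales one row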
diff --git a/algoperf/workloads/wmt/wmt_pytorch/decode.py b/algoperf/workloads/wmt/wmt_pytorch/decode.py
index 26ff36650..7974412d7 100644
--- a/algoperf/workloads/wmt/wmt_pytorch/decode.py
+++ b/algoperf/workloads/wmt/wmt_pytorch/decode.py
@@ -21,8 +21,9 @@
NEG_INF = torch.tensor(-1.0e7, device=DEVICE)
-def brevity_penalty(alpha: float, length: Union[int,
- torch.Tensor]) -> torch.Tensor:
+def brevity_penalty(
+ alpha: float, length: Union[int, torch.Tensor]
+) -> torch.Tensor:
"""Brevity penalty function for beam search penalizing short sequences.
Args:
@@ -57,8 +58,9 @@ def flatten_beam_dim(x: torch.Tensor) -> torch.Tensor:
return x.view(-1, *x.shape[2:])
-def unflatten_beam_dim(x: torch.Tensor, batch_size: int,
- beam_size: int) -> torch.Tensor:
+def unflatten_beam_dim(
+ x: torch.Tensor, batch_size: int, beam_size: int
+) -> torch.Tensor:
"""Unflattens the first, flat batch*beam dimension of a non-scalar tensor."""
if x.dim() < 2: # ignore scalars (e.g. cache index)
return x
@@ -71,10 +73,12 @@ def flat_batch_beam_expand(x: torch.Tensor, beam_size: int) -> torch.Tensor:
return flatten_beam_dim(add_beam_dim(x, beam_size))
-def gather_beams(nested: Dict[str, Any],
- beam_indices: torch.Tensor,
- batch_size: int,
- new_beam_size: int) -> Dict[str, Any]:
+def gather_beams(
+ nested: Dict[str, Any],
+ beam_indices: torch.Tensor,
+ batch_size: int,
+ new_beam_size: int,
+) -> Dict[str, Any]:
"""Gathers the beam slices indexed by beam_indices into new beam tensor.
Args:
@@ -88,10 +92,13 @@ def gather_beams(nested: Dict[str, Any],
[batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
"""
batch_indices = torch.reshape(
- torch.div(
- torch.arange(batch_size * new_beam_size, device=DEVICE),
- new_beam_size,
- rounding_mode='floor'), (batch_size, new_beam_size))
+ torch.div(
+ torch.arange(batch_size * new_beam_size, device=DEVICE),
+ new_beam_size,
+ rounding_mode='floor',
+ ),
+ (batch_size, new_beam_size),
+ )
def gather_fn(x):
if x.dim() < 2: # ignore scalars (e.g. cache index)
@@ -101,10 +108,12 @@ def gather_fn(x):
return jax.tree.map(gather_fn, nested)
-def gather_topk_beams(nested: Dict[str, Any],
- score_or_log_prob: torch.Tensor,
- batch_size: int,
- new_beam_size: int) -> Dict[str, Any]:
+def gather_topk_beams(
+ nested: Dict[str, Any],
+ score_or_log_prob: torch.Tensor,
+ batch_size: int,
+ new_beam_size: int,
+) -> Dict[str, Any]:
"""Gathers the top-k beam slices given by score_or_log_prob array.
Args:
@@ -129,6 +138,7 @@ def gather_topk_beams(nested: Dict[str, Any],
@dataclass
class BeamState:
"""Holds beam search state data."""
+
# The position of the decoding loop in the length dimension.
cur_index: torch.Tensor # scalar int32: current decoded length index.
# The active sequence log probabilities and finished sequence scores.
@@ -143,49 +153,52 @@ class BeamState:
cache: Dict[str, Any] # Any dict (of dicts), with torch.Tensors as leafs.
-def beam_init(batch_size: int,
- beam_size: int,
- max_decode_len: int,
- cache: Dict[str, Any]) -> BeamState:
+def beam_init(
+ batch_size: int, beam_size: int, max_decode_len: int, cache: Dict[str, Any]
+) -> BeamState:
"""Initializes the beam search state data structure."""
cur_index0 = torch.tensor(0, device=DEVICE)
live_logprobs0 = torch.tile(
- torch.tensor([0.0] + [NEG_INF] * (beam_size - 1), device=DEVICE),
- [batch_size, 1])
+ torch.tensor([0.0] + [NEG_INF] * (beam_size - 1), device=DEVICE),
+ [batch_size, 1],
+ )
finished_scores0 = (
- torch.ones((batch_size, beam_size), device=DEVICE) * NEG_INF)
- live_seqs0 = torch.zeros((batch_size, beam_size, max_decode_len),
- dtype=torch.int32,
- device=DEVICE)
- finished_seqs0 = torch.zeros((batch_size, beam_size, max_decode_len),
- dtype=torch.int32,
- device=DEVICE)
- finished_flags0 = torch.zeros((batch_size, beam_size),
- dtype=torch.bool,
- device=DEVICE)
+ torch.ones((batch_size, beam_size), device=DEVICE) * NEG_INF
+ )
+ live_seqs0 = torch.zeros(
+ (batch_size, beam_size, max_decode_len), dtype=torch.int32, device=DEVICE
+ )
+ finished_seqs0 = torch.zeros(
+ (batch_size, beam_size, max_decode_len), dtype=torch.int32, device=DEVICE
+ )
+ finished_flags0 = torch.zeros(
+ (batch_size, beam_size), dtype=torch.bool, device=DEVICE
+ )
# add beam dimension to attention cache pytree elements
beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache)
return BeamState(
- cur_index=cur_index0,
- live_logprobs=live_logprobs0,
- finished_scores=finished_scores0,
- live_seqs=live_seqs0,
- finished_seqs=finished_seqs0,
- finished_flags=finished_flags0,
- cache=beam_cache0)
+ cur_index=cur_index0,
+ live_logprobs=live_logprobs0,
+ finished_scores=finished_scores0,
+ live_seqs=live_seqs0,
+ finished_seqs=finished_seqs0,
+ finished_flags=finished_flags0,
+ cache=beam_cache0,
+ )
# Beam search routine:
def beam_search(
- inputs: torch.Tensor,
- cache: Optional[Dict[str, Any]],
- tokens_to_logits: Callable,
- beam_size: int = 4,
- alpha: float = 0.6,
- eos_id: int = EOS_ID,
- max_decode_len: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ inputs: torch.Tensor,
+ cache: Optional[Dict[str, Any]],
+ tokens_to_logits: Callable,
+ beam_size: int = 4,
+ alpha: float = 0.6,
+ eos_id: int = EOS_ID,
+ max_decode_len: Optional[int] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
"""Beam search for transformer machine translation.
Args:
@@ -211,10 +224,9 @@ def beam_search(
end_marker = torch.tensor(eos_id, device=DEVICE)
# initialize beam search state
- beam_search_init_state = beam_init(batch_size,
- beam_size,
- max_decode_len,
- cache)
+ beam_search_init_state = beam_init(
+ batch_size, beam_size, max_decode_len, cache
+ )
def beam_search_loop_cond_fn(state: BeamState) -> bool:
"""Beam search loop termination condition."""
@@ -227,11 +239,12 @@ def beam_search_loop_cond_fn(state: BeamState) -> bool:
best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty
# Get the worst scores from finished sequences.
worst_finished_scores, _ = torch.min(
- state.finished_scores, dim=1, keepdim=True)
+ state.finished_scores, dim=1, keepdim=True
+ )
# Mask out scores from slots without any actual finished sequences.
- worst_finished_scores = torch.where(state.finished_flags,
- worst_finished_scores,
- NEG_INF)
+ worst_finished_scores = torch.where(
+ state.finished_flags, worst_finished_scores, NEG_INF
+ )
# If no best possible live score is better than current worst finished
# scores, the search cannot improve the finished set further.
search_terminated = torch.all(worst_finished_scores > best_live_scores)
@@ -248,7 +261,8 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
# --> [batch * beam, 1]
cur_index = state.cur_index
flat_ids = flatten_beam_dim(
- state.live_seqs[:batch_size, :beam_size, cur_index:cur_index + 1])
+ state.live_seqs[:batch_size, :beam_size, cur_index : cur_index + 1]
+ )
# Flatten beam dimension into batch to be compatible with model.
# {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
flat_cache = jax.tree.map(flatten_beam_dim, state.cache)
@@ -263,7 +277,8 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
# Unflatten beam dimension in attention cache arrays
# {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
new_cache = jax.tree.map(
- lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache)
+ lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache
+ )
# Gather log probabilities from logits
candidate_log_probs = F.log_softmax(logits, dim=-1)
@@ -287,13 +302,13 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
topk_log_probs, topk_indices = torch.topk(flat_log_probs, k=beams_to_keep)
# Recover the beam index by floor division.
topk_beam_indices = torch.div(
- topk_indices, vocab_size, rounding_mode='floor')
+ topk_indices, vocab_size, rounding_mode='floor'
+ )
# Gather 2*k top beams.
# --> [batch, 2*beams, length]
- topk_seq = gather_beams(state.live_seqs,
- topk_beam_indices,
- batch_size,
- beams_to_keep)
+ topk_seq = gather_beams(
+ state.live_seqs, topk_beam_indices, batch_size, beams_to_keep
+ )
# Append the most probable 2*K token IDs to the top 2*K sequences
# Recover token id by modulo division and expand Id array for broadcasting.
@@ -301,11 +316,11 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
topk_ids = torch.unsqueeze(topk_indices % vocab_size, dim=2)
# Update sequences for the 2*K top-k new sequences.
# --> [batch, 2*beams, length]
- topk_seq[:, :, cur_index + 1:] = topk_ids
+ topk_seq[:, :, cur_index + 1 :] = topk_ids
# Update LIVE (in-progress) sequences:
# Did any of these sequences reach an end marker?
# --> [batch, 2*beams]
- newly_finished = (topk_seq[:, :, cur_index + 1] == end_marker)
+ newly_finished = topk_seq[:, :, cur_index + 1] == end_marker
# To prevent these newly finished sequences from being added to the LIVE
# set of active beam search sequences, set their log probs to a very large
# negative value.
@@ -316,22 +331,20 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
new_topk_indices = torch.flip(new_topk_indices, (1,))
# Gather the top k beams (from top 2*k beams).
# --> [batch, beams, length], [batch, beams]
- top_alive_seq, top_alive_log_probs = gather_beams([topk_seq, new_log_probs],
- new_topk_indices,
- batch_size, beam_size)
+ top_alive_seq, top_alive_log_probs = gather_beams(
+ [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size
+ )
# Determine the top k beam indices from the original set of all beams.
# --> [batch, beams]
- top_alive_indices = gather_beams(topk_beam_indices,
- new_topk_indices,
- batch_size,
- beam_size)
+ top_alive_indices = gather_beams(
+ topk_beam_indices, new_topk_indices, batch_size, beam_size
+ )
# With these, gather the top k beam-associated caches.
# --> {[batch, beams, ...], ...}
- top_alive_cache = gather_beams(new_cache,
- top_alive_indices,
- batch_size,
- beam_size)
+ top_alive_cache = gather_beams(
+ new_cache, top_alive_indices, batch_size, beam_size
+ )
# Update FINISHED (reached end of sentence) sequences:
# Calculate new seq scores from log probabilities.
@@ -344,24 +357,33 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
# new finished sequence scores to existing finished scores and select the
# best from the new set of beams.
finished_seqs = torch.cat( # --> [batch, 3*beams, length]
- [state.finished_seqs, topk_seq], dim=1)
+ [state.finished_seqs, topk_seq], dim=1
+ )
finished_scores = torch.cat( # --> [batch, 3*beams]
- [state.finished_scores, new_scores], dim=1)
+ [state.finished_scores, new_scores], dim=1
+ )
finished_flags = torch.cat( # --> [batch, 3*beams]
- [state.finished_flags, newly_finished], dim=1)
+ [state.finished_flags, newly_finished], dim=1
+ )
# --> [batch, beams, length], [batch, beams], [batch, beams]
top_finished_seq, top_finished_scores, top_finished_flags = (
- gather_topk_beams([finished_seqs, finished_scores, finished_flags],
- finished_scores, batch_size, beam_size))
+ gather_topk_beams(
+ [finished_seqs, finished_scores, finished_flags],
+ finished_scores,
+ batch_size,
+ beam_size,
+ )
+ )
return BeamState(
- cur_index=cur_index + 1,
- live_logprobs=top_alive_log_probs,
- finished_scores=top_finished_scores,
- live_seqs=top_alive_seq,
- finished_seqs=top_finished_seq,
- finished_flags=top_finished_flags,
- cache=top_alive_cache)
+ cur_index=cur_index + 1,
+ live_logprobs=top_alive_log_probs,
+ finished_scores=top_finished_scores,
+ live_seqs=top_alive_seq,
+ finished_seqs=top_finished_seq,
+ finished_flags=top_finished_flags,
+ cache=top_alive_cache,
+ )
state = beam_search_init_state
while beam_search_loop_cond_fn(state):
@@ -373,12 +395,16 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
# --> [batch]
none_finished = torch.any(final_state.finished_flags, dim=1)
# --> [batch, beams, length]
- finished_seqs = torch.where(none_finished[:, None, None],
- final_state.finished_seqs,
- final_state.live_seqs)
+ finished_seqs = torch.where(
+ none_finished[:, None, None],
+ final_state.finished_seqs,
+ final_state.live_seqs,
+ )
# --> [batch, beams]
- finished_scores = torch.where(none_finished[:, None],
- final_state.finished_scores,
- final_state.live_logprobs)
+ finished_scores = torch.where(
+ none_finished[:, None],
+ final_state.finished_scores,
+ final_state.live_logprobs,
+ )
return finished_seqs, finished_scores
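Note (not part of the patch): the shape bookkeeping behind `flatten_beam_dim`, `unflatten_beam_dim`, and `flat_batch_beam_expand` above, as a standalone sketch. Beams are folded into the batch dimension so the decoder sees an ordinary `[batch * beam, ...]` batch, and are unfolded again for per-beam scoring.

import torch

batch_size, beam_size, seq_len = 2, 4, 7
x = torch.zeros(batch_size, beam_size, seq_len)

flat = x.view(-1, *x.shape[2:])  # flatten_beam_dim: [batch * beam, length]
assert flat.shape == (batch_size * beam_size, seq_len)

unflat = flat.view(batch_size, beam_size, *flat.shape[1:])  # unflatten_beam_dim
assert unflat.shape == (batch_size, beam_size, seq_len)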
diff --git a/algoperf/workloads/wmt/wmt_pytorch/models.py b/algoperf/workloads/wmt/wmt_pytorch/models.py
index a1c7ce15e..430cc945b 100644
--- a/algoperf/workloads/wmt/wmt_pytorch/models.py
+++ b/algoperf/workloads/wmt/wmt_pytorch/models.py
@@ -3,11 +3,11 @@
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
-from torch import nn
-from torch import Tensor
import torch.nn.functional as F
-from torch.nn.init import normal_
-from torch.nn.init import xavier_uniform_
+from torch import Tensor, nn
+from torch.nn.init import normal_, xavier_uniform_
+
+DROPOUT_RATE = 0.1
def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor:
@@ -21,7 +21,8 @@ def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor:
A `[batch..., len, len]` shaped causal attention mask.
"""
idxs = torch.broadcast_to(
- torch.arange(x.shape[-1], dtype=torch.int32, device=device), x.shape)
+ torch.arange(x.shape[-1], dtype=torch.int32, device=device), x.shape
+ )
return torch.greater_equal(idxs.unsqueeze(-1), idxs.unsqueeze(-2))
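Note (not part of the patch): what `make_causal_mask` produces, shown on a toy length-4 sequence. Entry (i, j) is True iff position i may attend to position j, i.e. j <= i, giving a lower-triangular boolean mask.

import torch

idxs = torch.arange(4)
mask = torch.greater_equal(idxs.unsqueeze(-1), idxs.unsqueeze(-2))
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])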
@@ -31,55 +32,60 @@ def make_src_mask(src, inputs_segmentation, nhead):
# Add segmentation block-diagonal attention mask if using segmented data.
if inputs_segmentation is not None:
src_mask = torch.logical_and(
- src_mask,
- torch.eq(
- inputs_segmentation.unsqueeze(-1),
- inputs_segmentation.unsqueeze(-2)))
+ src_mask,
+ torch.eq(
+ inputs_segmentation.unsqueeze(-1), inputs_segmentation.unsqueeze(-2)
+ ),
+ )
# Flip values and ensure numerical stability.
src_mask = torch.repeat_interleave(
- torch.logical_not(src_mask), repeats=nhead, dim=0)
+ torch.logical_not(src_mask), repeats=nhead, dim=0
+ )
new_src_mask = torch.zeros_like(src_mask, dtype=torch.float32)
new_src_mask.masked_fill_(src_mask, -1e10)
return new_src_mask
-def make_tgt_and_memory_mask(tgt,
- src,
- inputs_segmentation,
- targets_segmentation,
- decode,
- nhead):
- """ Utility for creating target and memory mask and adjust them for PyTorch
+def make_tgt_and_memory_mask(
+ tgt, src, inputs_segmentation, targets_segmentation, decode, nhead
+):
+ """Utility for creating target and memory mask and adjust them for PyTorch
Transformer API."""
if not decode:
tgt_mask = torch.logical_and(
- torch.mul((tgt > 0).unsqueeze(-1), (tgt > 0).unsqueeze(-2)),
- make_causal_mask(tgt, device=tgt.device))
+ torch.mul((tgt > 0).unsqueeze(-1), (tgt > 0).unsqueeze(-2)),
+ make_causal_mask(tgt, device=tgt.device),
+ )
memory_mask = torch.mul((tgt > 0).unsqueeze(-1), (src > 0).unsqueeze(-2))
else:
tgt_mask = None
- memory_mask = torch.mul((torch.ones_like(tgt) > 0).unsqueeze(-1),
- (src > 0).unsqueeze(-2))
+ memory_mask = torch.mul(
+ (torch.ones_like(tgt) > 0).unsqueeze(-1), (src > 0).unsqueeze(-2)
+ )
# Add segmentation block-diagonal attention masks if using segmented data.
if inputs_segmentation is not None:
tgt_mask = torch.logical_and(
- tgt_mask,
- torch.eq(
- targets_segmentation.unsqueeze(-1),
- targets_segmentation.unsqueeze(-2)))
+ tgt_mask,
+ torch.eq(
+ targets_segmentation.unsqueeze(-1), targets_segmentation.unsqueeze(-2)
+ ),
+ )
memory_mask = torch.logical_and(
- memory_mask,
- torch.eq(
- targets_segmentation.unsqueeze(-1),
- inputs_segmentation.unsqueeze(-2)))
+ memory_mask,
+ torch.eq(
+ targets_segmentation.unsqueeze(-1), inputs_segmentation.unsqueeze(-2)
+ ),
+ )
# Flip values and ensure numerical stability.
memory_mask = torch.repeat_interleave(
- torch.logical_not(memory_mask), repeats=nhead, dim=0)
+ torch.logical_not(memory_mask), repeats=nhead, dim=0
+ )
new_memory_mask = torch.zeros_like(memory_mask, dtype=torch.float32)
new_memory_mask.masked_fill_(memory_mask, -1e10)
if tgt_mask is not None:
tgt_mask = torch.repeat_interleave(
- torch.logical_not(tgt_mask), repeats=nhead, dim=0)
+ torch.logical_not(tgt_mask), repeats=nhead, dim=0
+ )
new_tgt_mask = torch.zeros_like(tgt_mask, dtype=torch.float32)
new_tgt_mask.masked_fill_(tgt_mask, -1e10)
tgt_mask = new_tgt_mask
@@ -98,48 +104,44 @@ def shift_right(x, axis=1):
class Transformer(nn.Module):
"""Transformer architecture based on the model from the WMT Jax workload."""
- def __init__(self,
- ntoken: int = 32000,
- d_model: int = 1024,
- nhead: int = 16,
- d_hid: int = 1024,
- nlayers: int = 6,
- dropout_rate: Optional[float] = 0.1,
- attention_dropout_rate: Optional[float] = 0.1,
- activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
- glu: bool = False,
- layer_norm_eps: float = 1e-6,
- attention_temp: float = 1.0,
- pre_ln: bool = True):
+ def __init__(
+ self,
+ ntoken: int = 32000,
+ d_model: int = 1024,
+ nhead: int = 16,
+ d_hid: int = 1024,
+ nlayers: int = 6,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ glu: bool = False,
+ layer_norm_eps: float = 1e-6,
+ attention_temp: float = 1.0,
+ pre_ln: bool = True,
+ ):
super().__init__()
- if dropout_rate is None:
- dropout_rate = 0.1
- if attention_dropout_rate is None:
- attention_dropout_rate = 0.1
- self.pos_encoder = PositionalEncoding(d_model, dropout_rate)
+ self.pos_encoder = PositionalEncoding(d_model)
self.shared_embedding = nn.Embedding(ntoken, d_model)
- self.encoder = Encoder(d_model,
- nhead,
- d_hid,
- nlayers,
- dropout_rate,
- attention_dropout_rate,
- activation,
- glu,
- layer_norm_eps,
- attention_temp,
- pre_ln)
- self.decoder = Decoder(d_model,
- nhead,
- d_hid,
- nlayers,
- dropout_rate,
- attention_dropout_rate,
- activation,
- glu,
- layer_norm_eps,
- attention_temp,
- pre_ln)
+ self.encoder = Encoder(
+ d_model,
+ nhead,
+ d_hid,
+ nlayers,
+ activation,
+ glu,
+ layer_norm_eps,
+ attention_temp,
+ pre_ln,
+ )
+ self.decoder = Decoder(
+ d_model,
+ nhead,
+ d_hid,
+ nlayers,
+ activation,
+ glu,
+ layer_norm_eps,
+ attention_temp,
+ pre_ln,
+ )
# Share positional encoding and embedding between encoder and decoder.
self.encoder.pos_encoder = self.pos_encoder
self.encoder.shared_embedding = self.shared_embedding
@@ -156,14 +158,17 @@ def _reset_parameters(self):
if module.bias is not None:
normal_(module.bias, std=1e-6)
- def forward(self,
- src: Tensor,
- tgt: Tensor,
- inputs_positions: Optional[Tensor] = None,
- targets_positions: Optional[Tensor] = None,
- inputs_segmentation: Optional[Tensor] = None,
- targets_segmentation: Optional[Tensor] = None,
- decode: bool = False) -> Tensor:
+ def forward(
+ self,
+ src: Tensor,
+ tgt: Tensor,
+ inputs_positions: Optional[Tensor] = None,
+ targets_positions: Optional[Tensor] = None,
+ inputs_segmentation: Optional[Tensor] = None,
+ targets_segmentation: Optional[Tensor] = None,
+ decode: bool = False,
+ dropout_rate: float = DROPOUT_RATE,
+ ) -> Tensor:
"""
Args:
src: Tensor, shape [batch_size, seq_len]
@@ -173,24 +178,30 @@ def forward(self,
inputs_segmentation: Optional[Tensor], shape [batch_size, seq_len]
targets_segmentation: Optional[Tensor], shape [batch_size, seq_len]
decode: bool
+ dropout_rate: float
Returns:
output Tensor of shape [batch_size, seq_len, ntoken]
"""
if src.size(0) != tgt.size(0):
raise RuntimeError('The batch size of src and tgt must be equal.')
+
memory = self.encoder(
- src,
- inputs_positions=inputs_positions,
- inputs_segmentation=inputs_segmentation)
+ src,
+ inputs_positions=inputs_positions,
+ inputs_segmentation=inputs_segmentation,
+ dropout_rate=dropout_rate,
+ )
output = self.decoder(
- tgt,
- memory,
- src, # just for calculating the padding mask
- targets_positions=targets_positions,
- inputs_segmentation=inputs_segmentation,
- targets_segmentation=targets_segmentation,
- decode=decode)
+ tgt,
+ memory,
+ src, # just for calculating the padding mask
+ targets_positions=targets_positions,
+ inputs_segmentation=inputs_segmentation,
+ targets_segmentation=targets_segmentation,
+ decode=decode,
+ dropout_rate=dropout_rate,
+ )
return output
@@ -213,28 +224,38 @@ class TransformerEncoder(nn.Module):
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
"""
+
__constants__ = ['norm']
- def __init__(self,
- encoder_layer,
- num_layers,
- norm=None,
- enable_nested_tensor=True,
- mask_check=True):
+ def __init__(
+ self,
+ encoder_layer,
+ num_layers,
+ norm=None,
+ enable_nested_tensor=True,
+ mask_check=True,
+ ):
super().__init__()
self.layers = nn.ModuleList(
- [copy.deepcopy(encoder_layer) for _ in range(num_layers)])
+ [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
+ )
self.num_layers = num_layers
self.norm = norm
self.enable_nested_tensor = enable_nested_tensor
self.mask_check = mask_check
- def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+ def forward(
+ self,
+ src: Tensor,
+ mask: Optional[Tensor] = None,
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Tensor:
"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
mask: the mask for the src sequence (optional).
+ dropout_rate: the dropout probability (optional).
Shape:
see the docs in Transformer class.
@@ -243,10 +264,10 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
convert_to_nested = False
for mod in self.layers:
- output = mod(output, src_mask=mask)
+ output = mod(output, src_mask=mask, dropout_rate=dropout_rate)
if convert_to_nested:
- output = output.to_padded_tensor(0.)
+ output = output.to_padded_tensor(0.0)
if self.norm is not None:
output = self.norm(output)
@@ -255,109 +276,118 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
class Encoder(nn.Module):
-
- def __init__(self,
- d_model: int = 1024,
- nhead: int = 16,
- d_hid: int = 1024,
- nlayers: int = 6,
- dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.1,
- activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
- glu: bool = False,
- layer_norm_eps: float = 1e-6,
- attention_temp: float = 1.0,
- pre_ln: bool = True):
+ def __init__(
+ self,
+ d_model: int = 1024,
+ nhead: int = 16,
+ d_hid: int = 1024,
+ nlayers: int = 6,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ glu: bool = False,
+ layer_norm_eps: float = 1e-6,
+ attention_temp: float = 1.0,
+ pre_ln: bool = True,
+ ):
super().__init__()
self.nhead = nhead
self.shared_embedding = None
self.pos_encoder = None
encoder_layer = TransformerEncoderLayer(
- d_model,
- nhead,
- d_hid,
- dropout_rate,
- attention_dropout_rate=attention_dropout_rate,
- activation=activation,
- glu=glu,
- layer_norm_eps=layer_norm_eps,
- attention_temp=attention_temp,
- pre_ln=pre_ln)
- encoder_norm = (
- nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None)
+ d_model,
+ nhead,
+ d_hid,
+ activation=activation,
+ glu=glu,
+ layer_norm_eps=layer_norm_eps,
+ attention_temp=attention_temp,
+ pre_ln=pre_ln,
+ )
+ encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None
self.encoder = TransformerEncoder(encoder_layer, nlayers, encoder_norm)
- def forward(self,
- src: Tensor,
- inputs_positions: Optional[Tensor] = None,
- inputs_segmentation: Optional[Tensor] = None) -> Tensor:
+ def forward(
+ self,
+ src: Tensor,
+ inputs_positions: Optional[Tensor] = None,
+ inputs_segmentation: Optional[Tensor] = None,
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Tensor:
src = src.to(torch.int)
src_mask = make_src_mask(src, inputs_segmentation, self.nhead)
src = self.shared_embedding(src)
- src = self.pos_encoder(src, inputs_positions)
- memory = self.encoder(src, mask=src_mask)
+ src = self.pos_encoder(src, inputs_positions, dropout_rate=dropout_rate)
+ memory = self.encoder(src, mask=src_mask, dropout_rate=dropout_rate)
return memory
class Decoder(nn.Module):
-
- def __init__(self,
- d_model: int = 1024,
- nhead: int = 16,
- d_hid: int = 1024,
- nlayers: int = 6,
- dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.1,
- activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
- glu: bool = False,
- layer_norm_eps: float = 1e-6,
- attention_temp: float = 1.0,
- pre_ln: bool = True):
+ def __init__(
+ self,
+ d_model: int = 1024,
+ nhead: int = 16,
+ d_hid: int = 1024,
+ nlayers: int = 6,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ glu: bool = False,
+ layer_norm_eps: float = 1e-6,
+ attention_temp: float = 1.0,
+ pre_ln: bool = True,
+ ):
super().__init__()
self.nhead = nhead
self.shared_embedding = None
self.pos_encoder = None
- self.decoder = TransformerDecoder(d_model,
- nhead,
- d_hid,
- dropout_rate,
- attention_dropout_rate,
- activation,
- glu,
- layer_norm_eps,
- nlayers,
- attention_temp,
- pre_ln)
+ self.decoder = TransformerDecoder(
+ d_model,
+ nhead,
+ d_hid,
+ activation,
+ glu,
+ layer_norm_eps,
+ nlayers,
+ attention_temp,
+ pre_ln,
+ )
def forward(
- self,
- tgt: Tensor,
- memory: Tensor,
- src: Tensor, # just for calculating the padding mask
- targets_positions: Optional[Tensor] = None,
- inputs_segmentation: Optional[Tensor] = None,
- targets_segmentation: Optional[Tensor] = None,
- decode: bool = False,
- max_len: Optional[int] = None,
- cache: Optional[dict] = None) -> Any:
+ self,
+ tgt: Tensor,
+ memory: Tensor,
+ src: Tensor, # just for calculating the padding mask
+ targets_positions: Optional[Tensor] = None,
+ inputs_segmentation: Optional[Tensor] = None,
+ targets_segmentation: Optional[Tensor] = None,
+ decode: bool = False,
+ max_len: Optional[int] = None,
+ cache: Optional[dict] = None,
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Any:
tgt = tgt.to(torch.int)
tgt_mask, memory_mask = make_tgt_and_memory_mask(
- tgt, src, inputs_segmentation, targets_segmentation,
- decode, self.nhead)
+ tgt, src, inputs_segmentation, targets_segmentation, decode, self.nhead
+ )
if not decode:
tgt = shift_right(tgt)
tgt = self.shared_embedding(tgt)
- tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache)
+ tgt = self.pos_encoder(
+ tgt,
+ targets_positions,
+ decode=decode,
+ cache=cache,
+ dropout_rate=dropout_rate,
+ )
if decode:
tgt, cache = tgt
output = self.decoder(
- tgt,
- memory,
- tgt_mask=tgt_mask,
- memory_mask=memory_mask,
- decode=decode,
- max_len=max_len,
- cache=cache)
+ tgt,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ decode=decode,
+ max_len=max_len,
+ cache=cache,
+ dropout_rate=dropout_rate,
+ )
if decode:
output, cache = output
normalize = math.sqrt(output.shape[-1])
@@ -368,28 +398,24 @@ def forward(
class PositionalEncoding(nn.Module):
-
- def __init__(self,
- d_model: int,
- dropout_rate: float = 0.1,
- max_len: int = 256):
+ def __init__(self, d_model: int, max_len: int = 256):
super().__init__()
- self.dropout = nn.Dropout(p=dropout_rate)
position = torch.arange(max_len).unsqueeze(1)
scale_factor = -math.log(10000.0) / (d_model // 2 - 1)
div_term = torch.exp(torch.arange(d_model // 2) * scale_factor)
pe = torch.zeros(1, max_len, d_model)
- pe[0, :, :d_model // 2] = torch.sin(position * div_term)
- pe[0, :, d_model // 2:2 * (d_model // 2)] = torch.cos(position * div_term)
+ pe[0, :, : d_model // 2] = torch.sin(position * div_term)
+ pe[0, :, d_model // 2 : 2 * (d_model // 2)] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(
- self,
- x: Tensor,
- inputs_positions: Optional[Tensor] = None,
- decode: bool = False,
- cache: Optional[Dict[str, Dict[str, Tensor]]] = None
+ self,
+ x: Tensor,
+ inputs_positions: Optional[Tensor] = None,
+ decode: bool = False,
+ cache: Optional[Dict[str, Dict[str, Tensor]]] = None,
+ dropout_rate: Optional[float] = 0.0,
) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]:
"""
Args:
@@ -397,6 +423,7 @@ def forward(
inputs_positions: Tensor (shape [batch_size, seq_len]) or None
decode: bool
cache: Dict[str, Dict[str, Tensor]] or None
+ dropout_rate: Optional[float]
Returns:
Tensor or Tuple[Tensor, Dict[str, Dict[str, Tensor]]]
"""
@@ -405,21 +432,22 @@ def forward(
name = self._get_name()
if cache is None:
cache = {
- name: {
- 'cache_index':
- torch.tensor(0, dtype=torch.long, device=self.pe.device),
- },
+ name: {
+ 'cache_index': torch.tensor(
+ 0, dtype=torch.long, device=self.pe.device
+ ),
+ },
}
pe = self.pe[0, cache[name]['cache_index'], :]
cache[name]['cache_index'] += 1
- return self.dropout(x + pe), cache
+ return F.dropout(x + pe, dropout_rate, self.training), cache
if inputs_positions is None:
# normal unpacked case:
- pe = self.pe[:, :x.size(1), :]
+ pe = self.pe[:, : x.size(1), :]
else:
# for packed data we need to use known position indices:
pe = self.pe[0, inputs_positions, :]
- return self.dropout(x + pe)
+ return F.dropout(x + pe, dropout_rate, self.training)
# TransformerEncoderLayer and TransformerDecoderLayer are taken from:
@@ -438,7 +466,6 @@ class TransformerEncoderLayer(nn.Module):
nhead: the number of heads in the multiheadattention models (default=16).
dim_feedforward: the dimension of the feedforward network model
(default=1024).
- dropout_rate: the dropout_rate value (default=0.1).
activation: the activation function of the intermediate layer, can be a
string ("relu" or "gelu") or a unary callable (default=F.relu).
layer_norm_eps: the eps value in layer normalization components
@@ -451,81 +478,91 @@ class TransformerEncoderLayer(nn.Module):
>>> src = torch.rand(32, 10, 512)
>>> out = encoder_layer(src)
"""
+
__constants__ = ['pre_ln']
- def __init__(self,
- d_model: int = 1024,
- nhead: int = 16,
- dim_feedforward: int = 1024,
- dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.1,
- activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
- glu: bool = False,
- layer_norm_eps: float = 1e-6,
- attention_temp: float = 1.0,
- pre_ln: bool = True,
- device=None,
- dtype=None) -> None:
+ def __init__(
+ self,
+ d_model: int = 1024,
+ nhead: int = 16,
+ dim_feedforward: int = 1024,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ glu: bool = False,
+ layer_norm_eps: float = 1e-6,
+ attention_temp: float = 1.0,
+ pre_ln: bool = True,
+ device=None,
+ dtype=None,
+ ) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.self_attn = MultiheadAttention(
- d_model,
- nhead,
- self_attn=True,
- dropout_rate=attention_dropout_rate,
- attention_temp=attention_temp,
- bias=False,
- **factory_kwargs)
+ d_model,
+ nhead,
+ self_attn=True,
+ attention_temp=attention_temp,
+ bias=False,
+ **factory_kwargs,
+ )
# Implementation of Feedforward model.
self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
self.glu = glu
if self.glu:
self.linear_glu = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
- self.dropout = nn.Dropout(dropout_rate)
self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
self.pre_ln = pre_ln
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
- self.dropout1 = nn.Dropout(dropout_rate)
- self.dropout2 = nn.Dropout(dropout_rate)
self.activation = activation
- def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
+ def forward(
+ self,
+ src: Tensor,
+ src_mask: Optional[Tensor] = None,
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
-
+ dropout_rate: the dropout probability value (optional).
Shape:
see the docs in Transformer class.
"""
x = src
if self.pre_ln:
- x = x + self._sa_block(self.norm1(x), src_mask)
- x = x + self._ff_block(self.norm2(x))
+ x = x + self._sa_block(self.norm1(x), src_mask, dropout_rate)
+ x = x + self._ff_block(self.norm2(x), dropout_rate)
else:
- x = self.norm1(x + self._sa_block(x, src_mask))
- x = self.norm2(x + self._ff_block(x))
+ x = self.norm1(x + self._sa_block(x, src_mask, dropout_rate))
+ x = self.norm2(x + self._ff_block(x, dropout_rate))
return x
# Self-attention block:
- def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tensor:
- x, _ = self.self_attn(x, attn_mask=attn_mask)
- return self.dropout1(x)
+ def _sa_block(
+ self,
+ x: Tensor,
+ attn_mask: Optional[Tensor],
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Tensor:
+ x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate)
+ return F.dropout(x, dropout_rate, training=self.training)
# Feed forward block:
- def _ff_block(self, inputs: Tensor) -> Tensor:
+ def _ff_block(
+ self, inputs: Tensor, dropout_rate: Optional[float] = 0.0
+ ) -> Tensor:
x = self.activation(self.linear1(inputs))
if self.glu:
y = self.linear_glu(inputs)
x = x * y
- x = self.linear2(self.dropout(x))
- return self.dropout2(x)
+ x = self.linear2(F.dropout(x, dropout_rate, training=self.training))
+ return F.dropout(x, dropout_rate, training=self.training)
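Note (not part of the patch): a standalone sketch of the GLU variant in `_ff_block` above. When `glu=True`, the activated first projection is gated elementwise by a second linear projection of the same input before the output projection; the layer names below mirror the ones above, but the snippet is illustrative only.

import torch
import torch.nn.functional as F
from torch import nn

d_model, d_ff = 8, 16
linear1 = nn.Linear(d_model, d_ff)
linear_glu = nn.Linear(d_model, d_ff)
linear2 = nn.Linear(d_ff, d_model)

inputs = torch.randn(2, d_model)
x = F.relu(linear1(inputs)) * linear_glu(inputs)  # activation gated by linear_glu
out = linear2(F.dropout(x, 0.1, training=True))   # dropout applied functionally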
# Modified to use cache for autoregressive decoding and custom
@@ -537,7 +574,6 @@ class TransformerDecoder(nn.Module):
nhead: the number of heads in the multiheadattention models (default=16)
d_hid: the dimension of the feedforward network model
(default=1024)
- dropout_rate: the dropout_rate value (default=0.1)
layer_norm_eps: the eps value in layer normalization components
(default=1e-6).
decoder_layer: an instance of the TransformerDecoderLayer() class
@@ -549,45 +585,51 @@ class TransformerDecoder(nn.Module):
>>> tgt = torch.rand(20, 32, 512)
>>> out = transformer_decoder(tgt, memory)
"""
+
__constants__ = ['norm']
- def __init__(self,
- d_model,
- nhead,
- d_hid,
- dropout_rate,
- attention_dropout_rate,
- activation,
- glu,
- layer_norm_eps,
- num_layers,
- attention_temp,
- pre_ln):
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ d_hid,
+ activation,
+ glu,
+ layer_norm_eps,
+ num_layers,
+ attention_temp,
+ pre_ln,
+ ):
super().__init__()
- self.layers = nn.ModuleList([
+ self.layers = nn.ModuleList(
+ [
TransformerDecoderLayer(
- d_model,
- nhead,
- d_hid,
- dropout_rate,
- attention_dropout_rate,
- activation,
- glu,
- layer_norm_eps=layer_norm_eps,
- attention_temp=attention_temp,
- pre_ln=pre_ln) for _ in range(num_layers)
- ])
+ d_model,
+ nhead,
+ d_hid,
+ activation,
+ glu,
+ layer_norm_eps=layer_norm_eps,
+ attention_temp=attention_temp,
+ pre_ln=pre_ln,
+ )
+ for _ in range(num_layers)
+ ]
+ )
self.num_layers = num_layers
- self.norm = (nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None)
-
- def forward(self,
- tgt: Tensor,
- memory: Tensor,
- tgt_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- decode: bool = False,
- max_len: Optional[int] = None,
- cache: Optional[dict] = None) -> Any:
+ self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None
+
+ def forward(
+ self,
+ tgt: Tensor,
+ memory: Tensor,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ decode: bool = False,
+ max_len: Optional[int] = None,
+ cache: Optional[dict] = None,
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Any:
r"""Pass the inputs (and mask) through the decoder layer in turn.
Args:
tgt: the sequence to the decoder (required).
@@ -596,6 +638,7 @@ def forward(self,
memory_mask: the mask for the memory sequence (optional).
decode: whether to use cache for autoregressive decoding or not.
max_len: maximum sequence length, necessary for decoding cache.
+ dropout_rate: the dropout probability value (optional).
Shape:
see the docs in Transformer class.
"""
@@ -603,14 +646,16 @@ def forward(self,
for idx, mod in enumerate(self.layers):
output, cache = mod(
- output,
- memory,
- tgt_mask=tgt_mask,
- memory_mask=memory_mask,
- decode=decode,
- max_len=max_len,
- cache=cache,
- index=idx)
+ output,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ decode=decode,
+ max_len=max_len,
+ cache=cache,
+ index=idx,
+ dropout_rate=dropout_rate,
+ )
if self.norm is not None:
output = self.norm(output)
@@ -636,7 +681,6 @@ class TransformerDecoderLayer(nn.Module):
nhead: the number of heads in the multiheadattention models (default=16).
dim_feedforward: the dimension of the feedforward network model
(default=1024).
- dropout_rate: the dropout_rate value (default=0.1).
activation: the activation function of the intermediate layer, can be a
string ("relu" or "gelu") or a unary callable (default=F.relu).
layer_norm_eps: the eps value in layer normalization components
@@ -650,70 +694,69 @@ class TransformerDecoderLayer(nn.Module):
>>> tgt = torch.rand(32, 20, 512)
>>> out = decoder_layer(tgt, memory)
"""
+
__constants__ = ['pre_ln']
- def __init__(self,
- d_model: int = 1024,
- nhead: int = 16,
- dim_feedforward: int = 1024,
- dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.1,
- activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
- glu: bool = False,
- layer_norm_eps: float = 1e-6,
- pre_ln: bool = True,
- attention_temp: float = 1.0,
- device=None,
- dtype=None) -> None:
+ def __init__(
+ self,
+ d_model: int = 1024,
+ nhead: int = 16,
+ dim_feedforward: int = 1024,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ glu: bool = False,
+ layer_norm_eps: float = 1e-6,
+ pre_ln: bool = True,
+ attention_temp: float = 1.0,
+ device=None,
+ dtype=None,
+ ) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.self_attn = MultiheadAttention(
- d_model,
- nhead,
- self_attn=True,
- dropout_rate=attention_dropout_rate,
- attention_temp=attention_temp,
- bias=False,
- **factory_kwargs)
+ d_model,
+ nhead,
+ self_attn=True,
+ attention_temp=attention_temp,
+ bias=False,
+ **factory_kwargs,
+ )
self.multihead_attn = MultiheadAttention(
- d_model,
- nhead,
- self_attn=False,
- dropout_rate=attention_dropout_rate,
- attention_temp=attention_temp,
- bias=False,
- **factory_kwargs)
+ d_model,
+ nhead,
+ self_attn=False,
+ attention_temp=attention_temp,
+ bias=False,
+ **factory_kwargs,
+ )
# Implementation of Feedforward model.
self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
self.glu = glu
if self.glu:
- self.linear_glu = nn.Linear(dim_feedforward,
- dim_feedforward,
- **factory_kwargs)
- self.dropout = nn.Dropout(dropout_rate)
+ self.linear_glu = nn.Linear(
+ dim_feedforward, dim_feedforward, **factory_kwargs
+ )
self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
self.pre_ln = pre_ln
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
- self.dropout1 = nn.Dropout(dropout_rate)
- self.dropout2 = nn.Dropout(dropout_rate)
- self.dropout3 = nn.Dropout(dropout_rate)
self.activation = activation
def forward( # pylint: disable=arguments-renamed
- self,
- tgt: Tensor,
- memory: Tensor,
- tgt_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- decode: bool = False,
- max_len: Optional[int] = None,
- cache: Optional[dict] = None,
- index: Optional[int] = None) -> Any:
+ self,
+ tgt: Tensor,
+ memory: Tensor,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ decode: bool = False,
+ max_len: Optional[int] = None,
+ cache: Optional[dict] = None,
+ index: Optional[int] = None,
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Any:
r"""Pass the inputs (and mask) through the decoder layer.
Args:
tgt: the sequence to the decoder layer (required).
@@ -722,6 +765,7 @@ def forward( # pylint: disable=arguments-renamed
memory_mask: the mask for the memory sequence (optional).
decode: whether to use cache for autoregressive decoding or not.
max_len: maximum sequence length, necessary for decoding cache.
+ dropout_rate: the dropout probability value (optional).
Shape:
see the docs in Transformer class.
"""
@@ -730,61 +774,78 @@ def forward( # pylint: disable=arguments-renamed
x = tgt
if self.pre_ln:
sa_out, cache = self._sa_block(
- self.norm1(x),
- tgt_mask,
- decode=decode,
- max_len=max_len,
- cache=cache,
- index=index)
+ self.norm1(x),
+ tgt_mask,
+ decode=decode,
+ max_len=max_len,
+ cache=cache,
+ index=index,
+ dropout_rate=dropout_rate,
+ )
x = x + sa_out
- x = x + self._mha_block(self.norm2(x), memory, memory_mask)
- x = x + self._ff_block(self.norm3(x))
+ x = x + self._mha_block(self.norm2(x), memory, memory_mask, dropout_rate)
+ x = x + self._ff_block(self.norm3(x), dropout_rate)
else:
sa_out, cache = self._sa_block(
- x,
- tgt_mask,
- decode=decode,
- max_len=max_len,
- cache=cache,
- index=index)
+ x,
+ tgt_mask,
+ decode=decode,
+ max_len=max_len,
+ cache=cache,
+ index=index,
+ dropout_rate=dropout_rate,
+ )
x = self.norm1(x + sa_out)
- x = self.norm2(x + self._mha_block(x, memory, memory_mask))
- x = self.norm3(x + self._ff_block(x))
+ x = self.norm2(x + self._mha_block(x, memory, memory_mask, dropout_rate))
+ x = self.norm3(x + self._ff_block(x, dropout_rate))
return x, cache
# Self-attention block:
def _sa_block( # pylint: disable=arguments-renamed
- self,
- x: Tensor,
- attn_mask: Optional[Tensor],
- decode: bool = False,
- max_len: Optional[int] = None,
- cache: Optional[dict] = None,
- index: Optional[int] = None) -> Any:
+ self,
+ x: Tensor,
+ attn_mask: Optional[Tensor],
+ decode: bool = False,
+ max_len: Optional[int] = None,
+ cache: Optional[dict] = None,
+ index: Optional[int] = None,
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Any:
x, cache = self.self_attn(
- x,
- attn_mask=attn_mask,
- decode=decode,
- max_len=max_len,
- cache=cache,
- index=index)
- return self.dropout1(x), cache
+ x,
+ attn_mask=attn_mask,
+ decode=decode,
+ max_len=max_len,
+ cache=cache,
+ index=index,
+ dropout_rate=dropout_rate,
+ )
+ return F.dropout(x, dropout_rate, self.training), cache
# Multihead attention block:
- def _mha_block(self, x: Tensor, mem: Tensor,
- attn_mask: Optional[Tensor]) -> Tensor:
- x, _ = self.multihead_attn(x, mem, attn_mask=attn_mask)
- return self.dropout2(x)
+ def _mha_block(
+ self,
+ x: Tensor,
+ mem: Tensor,
+ attn_mask: Optional[Tensor],
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Tensor:
+ x, _ = self.multihead_attn(
+ x, mem, attn_mask=attn_mask, dropout_rate=dropout_rate
+ )
+ return F.dropout(x, dropout_rate, self.training)
# Feed forward block.
- def _ff_block(self, inputs: Tensor) -> Tensor:
+ def _ff_block(
+ self, inputs: Tensor, dropout_rate: Optional[float] = 0.0
+ ) -> Tensor:
x = self.activation(self.linear1(inputs))
if self.glu:
y = self.linear_glu(inputs)
x = x * y
- x = self.linear2(self.dropout(x))
- return self.dropout3(x)
+ x = self.linear2(F.dropout(x, dropout_rate, self.training))
+ return F.dropout(x, dropout_rate, self.training)
class MultiheadAttention(nn.Module):
@@ -802,8 +863,6 @@ class MultiheadAttention(nn.Module):
``embed_dim // num_heads``).
self_attn: Whether self attention or encoder-decoder attention is used.
Default: ``True``.
- dropout_rate: Dropout probability on ``attn_output_weights``.
- Default: ``0.0`` (no dropout_rate).
bias: If specified, adds bias to input / output projection layers.
Default: ``False``.
device: The device of the module.
@@ -813,35 +872,38 @@ class MultiheadAttention(nn.Module):
>>> attn_output, cache = multihead_attn(x)
"""
- def __init__(self,
- embed_dim: int,
- num_heads: int,
- self_attn: bool = True,
- dropout_rate: float = 0.,
- attention_temp: float = 1.0,
- bias: bool = False,
- device: Optional[torch.device] = None,
- dtype: Optional[torch.dtype] = None) -> None:
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ self_attn: bool = True,
+ attention_temp: float = 1.0,
+ bias: bool = False,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> None:
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.self_attn = self_attn
- self.dropout = dropout_rate
self.head_dim = embed_dim // num_heads
self.attention_temp = attention_temp
- assert self.head_dim * num_heads == self.embed_dim, \
- 'embed_dim must be divisible by num_heads.'
+ assert self.head_dim * num_heads == self.embed_dim, (
+ 'embed_dim must be divisible by num_heads.'
+ )
factory_kwargs = {'device': device, 'dtype': dtype}
if self_attn:
# Self-attention.
self.in_proj = nn.Linear(
- embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
+ embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
+ )
else:
# Encoder-decoder attention.
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
self.kv_proj = nn.Linear(
- embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
+ embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs
+ )
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
self._reset_parameters()
@@ -854,14 +916,17 @@ def _reset_parameters(self):
if module.bias is not None:
normal_(module.bias, std=1e-6)
- def forward(self,
- x: Tensor,
- mem: Optional[Tensor] = None,
- attn_mask: Optional[Tensor] = None,
- decode: bool = False,
- max_len: Optional[int] = None,
- cache: Optional[dict] = None,
- index: Optional[int] = None) -> Any:
+ def forward(
+ self,
+ x: Tensor,
+ mem: Optional[Tensor] = None,
+ attn_mask: Optional[Tensor] = None,
+ decode: bool = False,
+ max_len: Optional[int] = None,
+ cache: Optional[dict] = None,
+ index: Optional[int] = None,
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Any: # TODO: (nico) remove default?!
r"""
Args:
x: Batch of input sequences of shape
@@ -869,7 +934,7 @@ def forward(self,
attention mechanism. See "Attention Is All You Need" for more details.
mem: Batch of input sequences of shape
(batch size, sequence length, embedding dimensionality) for
- encoder-decoder attention. See "Attention Is All You Need" for more
+ encoder-decoder attention. See "Attention Is All You Need" for more
details.
attn_mask: If specified, a 2D or 3D mask preventing attention to certain
positions. Must be of shape :math:`(L, S)` or
@@ -887,6 +952,7 @@ def forward(self,
max_len: maximum sequence length, necessary for decoding cache.
cache: cache dictionary for autoregressive decoding.
index: index of the current decoding step, necessary for decoding cache.
+ dropout_rate: dropout probability on ``attn_output_weights``.
Outputs:
- **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where
:math:`L` is the target sequence length, :math:`N` is the batch size,
@@ -911,16 +977,13 @@ def forward(self,
if decode:
if loc_cache is None:
loc_cache = {
- 'cached_key':
- torch.zeros((bsz, max_len, embed_dim),
- dtype=k.dtype,
- device=k.device),
- 'cached_value':
- torch.zeros((bsz, max_len, embed_dim),
- dtype=v.dtype,
- device=v.device),
- 'cache_index':
- torch.tensor(0, dtype=torch.long, device=k.device),
+ 'cached_key': torch.zeros(
+ (bsz, max_len, embed_dim), dtype=k.dtype, device=k.device
+ ),
+ 'cached_value': torch.zeros(
+ (bsz, max_len, embed_dim), dtype=v.dtype, device=v.device
+ ),
+ 'cache_index': torch.tensor(0, dtype=torch.long, device=k.device),
}
cached_key = loc_cache['cached_key']
cached_value = loc_cache['cached_value']
@@ -928,11 +991,13 @@ def forward(self,
# Shape check of cached keys against query input.
expected_shape = (bsz, 1, embed_dim)
if expected_shape != x.shape:
- raise ValueError('Autoregressive cache shape error, expected query '
- f'shape {expected_shape} instead got {x.shape}.')
+ raise ValueError(
+ 'Autoregressive cache shape error, expected query '
+ f'shape {expected_shape} instead got {x.shape}.'
+ )
# Update key, value caches with our new 1d spatial slices.
- cached_key[:, cache_index:cache_index + 1, :] = k
- cached_value[:, cache_index:cache_index + 1, :] = v
+ cached_key[:, cache_index : cache_index + 1, :] = k
+ cached_value[:, cache_index : cache_index + 1, :] = v
k = cached_key
v = cached_value
cache_index += 1
@@ -942,8 +1007,9 @@ def forward(self,
# not the remaining zero elements.
if attn_mask is not None:
raise ValueError('Attention mask has to be None for decode == True.')
- attn_mask = (torch.arange(max_len, device=k.device) >=
- cache_index).reshape(1, max_len)
+ attn_mask = (
+ torch.arange(max_len, device=k.device) >= cache_index
+ ).reshape(1, max_len)
# Update sequence length to account for complete sequence.
seq_len = k.size(1)
@@ -955,17 +1021,21 @@ def forward(self,
# Check dtype and shape of attention mask.
if not decode and attn_mask is not None:
- assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
- f'Float and bool dtypes are supported, not {attn_mask.dtype}.'
+ assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, (
+ f'Float and bool dtypes are supported, not {attn_mask.dtype}.'
+ )
# Ensure attn_mask's dim is 3.
if attn_mask.dim() == 3:
correct_3d_size = (bsz * self.num_heads, tgt_len, seq_len)
if attn_mask.shape != correct_3d_size:
- raise RuntimeError(f'The shape of attn_mask is {attn_mask.shape}, '
- f'but should be {correct_3d_size}.')
+ raise RuntimeError(
+ f'The shape of attn_mask is {attn_mask.shape}, '
+ f'but should be {correct_3d_size}.'
+ )
else:
raise RuntimeError(
- f"attn_mask's dimension {attn_mask.dim()} is not supported")
+ f"attn_mask's dimension {attn_mask.dim()} is not supported"
+ )
# Reshape attention mask to be consistent with q, k, v.
attn_mask = attn_mask.view(bsz, self.num_heads, tgt_len, seq_len)
@@ -976,15 +1046,17 @@ def forward(self,
attn_mask = new_attn_mask
# Adjust dropout_rate probability.
- dropout_rate = self.dropout if self.training else 0.0
+ attn_dropout_rate = dropout_rate if self.training else 0.0
# Calculate attention.
q = self.attention_temp * q
attn_output = torch.nn.functional.scaled_dot_product_attention(
- q, k, v, attn_mask, dropout_rate)
+ q, k, v, attn_mask, attn_dropout_rate
+ )
# Rearrange for output projection.
- attn_output = attn_output.transpose(1, 2).contiguous().view(
- bsz, tgt_len, embed_dim)
+ attn_output = (
+ attn_output.transpose(1, 2).contiguous().view(bsz, tgt_len, embed_dim)
+ )
# Output projection.
attn_output = self.out_proj(attn_output)
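Note (not part of the patch): the recurring change in this file is that the `nn.Dropout` submodules are removed and `dropout_rate` is instead threaded through `forward()` and applied with `F.dropout(..., training=self.training)`, so the rate can be varied per call without rebuilding the model. A minimal, self-contained sketch; `TinyBlock` is illustrative only.

import torch
import torch.nn.functional as F
from torch import nn


class TinyBlock(nn.Module):
  def __init__(self, dim: int = 8):
    super().__init__()
    self.linear = nn.Linear(dim, dim)

  def forward(self, x: torch.Tensor, dropout_rate: float = 0.1) -> torch.Tensor:
    # F.dropout is a no-op when training=False, mirroring nn.Dropout in eval mode.
    return F.dropout(self.linear(x), dropout_rate, training=self.training)


block = TinyBlock()
block.train()
y_train = block(torch.randn(2, 8), dropout_rate=0.3)  # per-call dropout rate
block.eval()
y_eval = block(torch.randn(2, 8))  # dropout disabled in eval mode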
diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py
index d0716d6c8..53d95d393 100644
--- a/algoperf/workloads/wmt/wmt_pytorch/workload.py
+++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py
@@ -3,20 +3,18 @@
import contextlib
from typing import Any, Dict, Optional, Tuple
-from absl import logging
import jax
import tensorflow as tf
import torch
import torch.distributed as dist
-from torch.nn import DataParallel as DP
import torch.nn.functional as F
+from absl import logging
+from torch.nn import DataParallel as DP
from torch.nn.parallel import DistributedDataParallel as DDP
-from algoperf import param_utils
-from algoperf import pytorch_utils
-from algoperf import spec
+from algoperf import param_utils, pytorch_utils, spec
from algoperf.workloads.wmt import bleu
-from algoperf.workloads.wmt.wmt_pytorch import decode
+from algoperf.workloads.wmt.wmt_pytorch import decode, models
from algoperf.workloads.wmt.wmt_pytorch.models import Transformer
from algoperf.workloads.wmt.workload import BaseWmtWorkload
@@ -27,11 +25,12 @@ class WmtWorkload(BaseWmtWorkload):
"""WMT PyTorch workload."""
def compute_weighted_cross_entropy(
- self,
- logits: spec.Tensor,
- targets: spec.Tensor,
- weights: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.1) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ logits: spec.Tensor,
+ targets: spec.Tensor,
+ weights: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.1,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Compute weighted cross entropy and entropy for log probs and targets.
Args:
@@ -46,11 +45,14 @@ def compute_weighted_cross_entropy(
valid examples in batch, 'per_example': 1-d array of per-example losses}
"""
if logits.ndim != targets.ndim + 1:
- raise ValueError(f'Incorrect shapes. Got shape {logits.shape} logits and '
- f'{targets.shape} targets.')
+ raise ValueError(
+ f'Incorrect shapes. Got shape {logits.shape} logits and '
+ f'{targets.shape} targets.'
+ )
loss_fn = torch.nn.CrossEntropyLoss(
- reduction='none', label_smoothing=label_smoothing)
+ reduction='none', label_smoothing=label_smoothing
+ )
if N_GPUS > 1 and not USE_PYTORCH_DDP:
loss_fn = DP(loss_fn)
@@ -59,24 +61,27 @@ def compute_weighted_cross_entropy(
if weights is None:
weights = torch.ones_like(targets)
per_example_losses = torch.where(
- weights.to(torch.bool), per_example_losses, 0.)
+ weights.to(torch.bool), per_example_losses, 0.0
+ )
summed_loss = per_example_losses.sum()
n_valid_examples = weights.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': n_valid_examples,
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': n_valid_examples,
+ 'per_example': per_example_losses,
}
# Primary eval / decode step functions.
# ----------------------------------------------------------------------------
@torch.no_grad()
- def predict_step(self,
- inputs: spec.Tensor,
- params: spec.ParameterContainer,
- eos_id: int,
- max_decode_len: int,
- beam_size: int = 4) -> spec.Tensor:
+ def predict_step(
+ self,
+ inputs: spec.Tensor,
+ params: spec.ParameterContainer,
+ eos_id: int,
+ max_decode_len: int,
+ beam_size: int = 4,
+ ) -> spec.Tensor:
"""Predict translation with fast decoding beam search on a batch."""
# params = params.module if isinstance(params, (DP, DDP)) else params
if hasattr(params, 'module'):
@@ -85,8 +90,8 @@ def predict_step(self,
if hasattr(params, '_modules'):
params = params._modules
- encoder = params["encoder"]
- decoder = params["decoder"]
+ encoder = params['encoder']
+ decoder = params['decoder']
else:
encoder = params.encoder
decoder = params.decoder
@@ -97,21 +102,23 @@ def predict_step(self,
decoder = DP(decoder)
encoded_inputs = torch.repeat_interleave(
- encoder(inputs), repeats=beam_size, dim=0)
+ encoder(inputs), repeats=beam_size, dim=0
+ )
raw_inputs = torch.repeat_interleave(inputs, repeats=beam_size, dim=0)
def tokens_ids_to_logits(
- flat_ids: spec.Tensor, flat_cache: Dict[str, spec.Tensor]
+ flat_ids: spec.Tensor, flat_cache: Dict[str, spec.Tensor]
) -> Tuple[spec.Tensor, Dict[str, spec.Tensor]]:
"""Token slice to logits from decoder model."""
# --> [batch * beam, 1, vocab]
flat_logits, new_flat_cache = decoder(
- flat_ids,
- encoded_inputs,
- raw_inputs,
- decode=True,
- max_len=max_decode_len,
- cache=flat_cache)
+ flat_ids,
+ encoded_inputs,
+ raw_inputs,
+ decode=True,
+ max_len=max_decode_len,
+ cache=flat_cache,
+ )
# Remove singleton sequence-length dimension:
# [batch * beam, 1, vocab] --> [batch * beam, vocab]
flat_logits = flat_logits.squeeze(dim=1)
@@ -120,24 +127,27 @@ def tokens_ids_to_logits(
# Using the above-defined single-step decoder function, run a
# beam search over possible sequences given input encoding.
beam_seqs, _ = decode.beam_search(
- inputs,
- None,
- tokens_ids_to_logits,
- beam_size=beam_size,
- alpha=0.6,
- eos_id=eos_id,
- max_decode_len=max_decode_len)
+ inputs,
+ None,
+ tokens_ids_to_logits,
+ beam_size=beam_size,
+ alpha=0.6,
+ eos_id=eos_id,
+ max_decode_len=max_decode_len,
+ )
# Beam search returns [n_batch, n_beam, n_length + 1] with beam dimension
# sorted in increasing order of log-probability.
# Return the highest scoring beam sequence, drop first dummy 0 token.
return beam_seqs[:, -1, 1:]
- def translate_and_calculate_bleu(self,
- params: spec.ParameterContainer,
- ds_iter: tf.data.Dataset,
- num_batches: int,
- max_predict_length: int):
+ def translate_and_calculate_bleu(
+ self,
+ params: spec.ParameterContainer,
+ ds_iter: tf.data.Dataset,
+ num_batches: int,
+ max_predict_length: int,
+ ):
"""Translates the `ds_iter` and calculates the BLEU score."""
logging.info('Translating evaluation dataset.')
references, predictions = [], []
@@ -145,10 +155,9 @@ def translate_and_calculate_bleu(self,
pred_batch = next(ds_iter)
inputs = pred_batch['inputs']
targets = pred_batch['targets']
- predicted = self.predict_step(inputs,
- params,
- decode.EOS_ID,
- max_predict_length)
+ predicted = self.predict_step(
+ inputs, params, decode.EOS_ID, max_predict_length
+ )
# Find actual batch size, ignoring the potential padding.
weights = pred_batch.get('weights')
@@ -165,12 +174,7 @@ def translate_and_calculate_bleu(self,
bleu_score = bleu.corpus_bleu(predictions, [references]).score
return bleu_score
- def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """aux_dropout_rate is used as attention_dropout_rate."""
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
if self.activation == 'relu':
@@ -181,12 +185,11 @@ def init_model_fn(
raise ValueError(f'Unknown activation function {self.activation}.')
model = Transformer(
- dropout_rate=dropout_rate,
- attention_dropout_rate=aux_dropout_rate,
- pre_ln=self.pre_ln,
- attention_temp=self.attention_temp,
- activation=activation,
- glu=self.glu)
+ pre_ln=self.pre_ln,
+ attention_temp=self.attention_temp,
+ activation=activation,
+ glu=self.glu,
+ )
self._param_shapes = param_utils.pytorch_param_shapes(model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
model.to(DEVICE)
@@ -201,13 +204,15 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'shared_embedding.weight'
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
@@ -218,43 +223,53 @@ def model_fn(
model.eval()
contexts = {
- spec.ForwardPassMode.EVAL: torch.no_grad,
- spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+ spec.ForwardPassMode.EVAL: torch.no_grad,
+ spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
logits_batch = model(
- src=augmented_and_preprocessed_input_batch['inputs'],
- tgt=augmented_and_preprocessed_input_batch['targets'],
- inputs_positions=augmented_and_preprocessed_input_batch.get(
- 'inputs_position', None),
- targets_positions=augmented_and_preprocessed_input_batch.get(
- 'targets_position', None),
- inputs_segmentation=augmented_and_preprocessed_input_batch.get(
- 'inputs_segmentation', None),
- targets_segmentation=augmented_and_preprocessed_input_batch.get(
- 'targets_segmentation', None))
+ src=augmented_and_preprocessed_input_batch['inputs'],
+ tgt=augmented_and_preprocessed_input_batch['targets'],
+ inputs_positions=augmented_and_preprocessed_input_batch.get(
+ 'inputs_position', None
+ ),
+ targets_positions=augmented_and_preprocessed_input_batch.get(
+ 'targets_position', None
+ ),
+ inputs_segmentation=augmented_and_preprocessed_input_batch.get(
+ 'inputs_segmentation', None
+ ),
+ targets_segmentation=augmented_and_preprocessed_input_batch.get(
+ 'targets_segmentation', None
+ ),
+ dropout_rate=dropout_rate,
+ )
return logits_batch, None
- def _build_input_queue(self,
- data_rng: jax.random.PRNGKey,
- split: str,
- data_dir: str,
- global_batch_size: int,
- num_batches: Optional[int] = None,
- repeat_final_dataset: bool = False):
+ def _build_input_queue(
+ self,
+ data_rng: jax.random.PRNGKey,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ num_batches: Optional[int] = None,
+ repeat_final_dataset: bool = False,
+ ):
per_device_batch_size = int(global_batch_size / N_GPUS)
n_inputs = 7 if split == 'train' else 3
# The input pipeline has to be created in all processes, because
# self._tokenizer has to be available in every process.
- np_iter = super()._build_input_queue(data_rng,
- split,
- data_dir,
- global_batch_size,
- num_batches,
- repeat_final_dataset)
+ np_iter = super()._build_input_queue(
+ data_rng,
+ split,
+ data_dir,
+ global_batch_size,
+ num_batches,
+ repeat_final_dataset,
+ )
# We only need np_iter in one Python process.
if RANK != 0:
del np_iter
@@ -269,14 +284,15 @@ def _build_input_queue(self,
tensor = torch.as_tensor(value, dtype=torch.int64, device=DEVICE)
tensor_list.append(tensor)
batch[key] = (
- tensor[0] if USE_PYTORCH_DDP else tensor.view(
- -1, value.shape[-1]))
+ tensor[0] if USE_PYTORCH_DDP else tensor.view(-1, value.shape[-1])
+ )
# Send batch to other devices when using DDP.
if USE_PYTORCH_DDP:
# During eval, the batch size of the remainder might be different.
if split != 'train':
per_device_batch_size = torch.tensor(
- len(batch['inputs']), dtype=torch.int32, device=DEVICE)
+ len(batch['inputs']), dtype=torch.int32, device=DEVICE
+ )
dist.broadcast(per_device_batch_size, src=0)
# We don't need to broadcast the batch for the device with RANK == 0.
dist.broadcast(torch.stack(tensor_list)[:, 1:].contiguous(), src=0)
@@ -284,25 +300,27 @@ def _build_input_queue(self,
batch = {}
# During eval, the batch size of the remainder might be different.
if split != 'train':
- per_device_batch_size = torch.empty((1,),
- dtype=torch.int32,
- device=DEVICE)
+ per_device_batch_size = torch.empty(
+ (1,), dtype=torch.int32, device=DEVICE
+ )
dist.broadcast(per_device_batch_size, src=0)
# N_GPUS - 1 since we don't broadcast the batch for RANK == 0.
- tensor = torch.empty((n_inputs, N_GPUS - 1, per_device_batch_size, 256),
- dtype=torch.int64,
- device=DEVICE)
+ tensor = torch.empty(
+ (n_inputs, N_GPUS - 1, per_device_batch_size, 256),
+ dtype=torch.int64,
+ device=DEVICE,
+ )
dist.broadcast(tensor, src=0)
# Note that the order of the keys is important.
if split == 'train':
keys = [
- 'inputs',
- 'inputs_position',
- 'inputs_segmentation',
- 'targets',
- 'targets_position',
- 'targets_segmentation',
- 'weights',
+ 'inputs',
+ 'inputs_position',
+ 'inputs_segmentation',
+ 'targets',
+ 'targets_position',
+ 'targets_segmentation',
+ 'weights',
]
# For all eval/test splits.
else:
@@ -312,34 +330,35 @@ def _build_input_queue(self,
batch[key] = tensor[n][RANK - 1]
yield batch
- def eval_step(self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]:
+ def eval_step(
+ self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]
+ ) -> Dict[str, spec.Tensor]:
"""Calculate evaluation metrics on a batch."""
targets = batch['targets']
weights = batch['weights']
logits, _ = self.model_fn(
- params,
- batch,
- mode=spec.ForwardPassMode.EVAL,
- model_state=None,
- rng=None,
- update_batch_norm=False)
- summed_loss = self.compute_weighted_cross_entropy(logits,
- targets,
- weights,
- 0.0)['summed']
+ params,
+ batch,
+ mode=spec.ForwardPassMode.EVAL,
+ model_state=None,
+ rng=None,
+ update_batch_norm=False,
+ )
+ summed_loss = self.compute_weighted_cross_entropy(
+ logits, targets, weights, 0.0
+ )['summed']
acc_sum, weight_sum = self.compute_weighted_accuracy(
- logits, targets, weights)
+ logits, targets, weights
+ )
return {
- 'loss': summed_loss,
- 'accuracy': acc_sum,
- 'denominator': weight_sum,
+ 'loss': summed_loss,
+ 'accuracy': acc_sum,
+ 'denominator': weight_sum,
}
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
del num_examples
if USE_PYTORCH_DDP:
diff --git a/algoperf/workloads/wmt/workload.py b/algoperf/workloads/wmt/workload.py
index 51b33373d..40e4262dd 100644
--- a/algoperf/workloads/wmt/workload.py
+++ b/algoperf/workloads/wmt/workload.py
@@ -60,8 +60,9 @@ def num_eval_train_examples(self) -> int:
# Round up from num_validation_examples (which is the default for
# num_eval_train_examples) to the next multiple of eval_batch_size, so that
# we don't have to extract the correctly sized subset of the training data.
- rounded_up_multiple = math.ceil(self.num_validation_examples /
- self.eval_batch_size)
+ rounded_up_multiple = math.ceil(
+ self.num_validation_examples / self.eval_batch_size
+ )
return rounded_up_multiple * self.eval_batch_size
@property
@@ -115,23 +116,26 @@ def activation(self) -> str:
def glu(self) -> bool:
return False
- def _build_input_queue(self,
- data_rng: jax.random.PRNGKey,
- split: str,
- data_dir: str,
- global_batch_size: int,
- num_batches: Optional[int] = None,
- repeat_final_dataset: bool = False):
+ def _build_input_queue(
+ self,
+ data_rng: jax.random.PRNGKey,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ num_batches: Optional[int] = None,
+ repeat_final_dataset: bool = False,
+ ):
is_training = split == 'train'
ds, self._tokenizer = input_pipeline.get_wmt_dataset(
- data_rng,
- split,
- data_dir,
- is_training=is_training,
- vocab_size=self._vocab_size,
- global_batch_size=global_batch_size,
- num_batches=num_batches,
- repeat_final_dataset=repeat_final_dataset)
+ data_rng,
+ split,
+ data_dir,
+ is_training=is_training,
+ vocab_size=self._vocab_size,
+ global_batch_size=global_batch_size,
+ num_batches=num_batches,
+ repeat_final_dataset=repeat_final_dataset,
+ )
# Separate function is necessary because the code above has to be executed
# when _build_input_queue is called (not when next() is first called on it).
@@ -148,19 +152,21 @@ def _input_queue_generator():
@abc.abstractmethod
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del model_state
del global_step
@@ -168,12 +174,13 @@ def _eval_model_on_split(self,
if split not in self._eval_iters:
# These iterators will repeat indefinitely.
self._eval_iters[split] = self._build_input_queue(
- rng,
- split,
- data_dir,
- global_batch_size,
- num_batches,
- repeat_final_dataset=True)
+ rng,
+ split,
+ data_dir,
+ global_batch_size,
+ num_batches,
+ repeat_final_dataset=True,
+ )
eval_metrics = {}
for _ in range(num_batches):
@@ -186,16 +193,17 @@ def _eval_model_on_split(self,
eval_results = self._normalize_eval_metrics(num_examples, eval_metrics)
eval_results['bleu'] = self.translate_and_calculate_bleu(
- params=params,
- ds_iter=self._eval_iters[split],
- num_batches=num_batches,
- max_predict_length=256)
+ params=params,
+ ds_iter=self._eval_iters[split],
+ num_batches=num_batches,
+ max_predict_length=256,
+ )
return eval_results
def compute_weighted_accuracy(
- self, logits: spec.Tensor, targets: spec.Tensor,
- weights: spec.Tensor) -> Tuple[spec.Tensor, spec.Tensor]:
+ self, logits: spec.Tensor, targets: spec.Tensor, weights: spec.Tensor
+ ) -> Tuple[spec.Tensor, spec.Tensor]:
"""Compute weighted accuracy for log probs and targets.
Args:
@@ -207,8 +215,10 @@ def compute_weighted_accuracy(
Tuple of scalar summed accuracy and batch normalizing factor.
"""
if logits.ndim != targets.ndim + 1:
- raise ValueError(f'Incorrect shapes. Got shape {logits.shape} logits and '
- f'{targets.shape} targets.')
+ raise ValueError(
+ f'Incorrect shapes. Got shape {logits.shape} logits and '
+ f'{targets.shape} targets.'
+ )
accuracy = (logits.argmax(-1) == targets) * weights
normalizing_factor = weights.sum()
return accuracy.sum(), normalizing_factor
@@ -216,17 +226,18 @@ def compute_weighted_accuracy(
def _decode_tokens(self, toks: spec.Tensor) -> spec.Tensor:
if isinstance(toks, torch.Tensor):
toks = toks.cpu().numpy()
- valid_toks = toks[:np.argmax(toks == decode.EOS_ID) + 1].astype(np.int32)
+ valid_toks = toks[: np.argmax(toks == decode.EOS_ID) + 1].astype(np.int32)
return self._tokenizer.detokenize(valid_toks).numpy().decode('utf-8')
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -234,7 +245,8 @@ def loss_fn(
(not synced across devices).
"""
return self.compute_weighted_cross_entropy(
- logits_batch,
- label_batch,
- weights=mask_batch,
- label_smoothing=label_smoothing)
+ logits_batch,
+ label_batch,
+ weights=mask_batch,
+ label_smoothing=label_smoothing,
+ )
diff --git a/algoperf/workloads/workloads.py b/algoperf/workloads/workloads.py
index 4712f4e25..4dd4717e9 100644
--- a/algoperf/workloads/workloads.py
+++ b/algoperf/workloads/workloads.py
@@ -1,5 +1,5 @@
-""" Registry of workload info
-"""
+"""Registry of workload info"""
+
import importlib
import inspect
import os
@@ -9,149 +9,151 @@
BASE_WORKLOADS_DIR = 'algoperf/workloads/'
WORKLOADS = {
- 'cifar': {
- 'workload_path': 'cifar/cifar', 'workload_class_name': 'CifarWorkload'
- },
- 'criteo1tb': {
- 'workload_path': 'criteo1tb/criteo1tb',
- 'workload_class_name': 'Criteo1TbDlrmSmallWorkload',
- },
- 'criteo1tb_test': {
- 'workload_path': 'criteo1tb/criteo1tb',
- 'workload_class_name': 'Criteo1TbDlrmSmallTestWorkload',
- },
- 'criteo1tb_layernorm': {
- 'workload_path': 'criteo1tb/criteo1tb',
- 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload'
- },
- 'criteo1tb_embed_init': {
- 'workload_path': 'criteo1tb/criteo1tb',
- 'workload_class_name': 'Criteo1TbDlrmSmallEmbedInitWorkload'
- },
- 'criteo1tb_resnet': {
- 'workload_path': 'criteo1tb/criteo1tb',
- 'workload_class_name': 'Criteo1TbDlrmSmallResNetWorkload'
- },
- 'fastmri': {
- 'workload_path': 'fastmri/fastmri',
- 'workload_class_name': 'FastMRIWorkload',
- },
- 'fastmri_model_size': {
- 'workload_path': 'fastmri/fastmri',
- 'workload_class_name': 'FastMRIModelSizeWorkload',
- },
- 'fastmri_tanh': {
- 'workload_path': 'fastmri/fastmri',
- 'workload_class_name': 'FastMRITanhWorkload',
- },
- 'fastmri_layernorm': {
- 'workload_path': 'fastmri/fastmri',
- 'workload_class_name': 'FastMRILayerNormWorkload',
- },
- 'imagenet_resnet': {
- 'workload_path': 'imagenet_resnet/imagenet',
- 'workload_class_name': 'ImagenetResNetWorkload',
- },
- 'imagenet_resnet_silu': {
- 'workload_path': 'imagenet_resnet/imagenet',
- 'workload_class_name': 'ImagenetResNetSiLUWorkload',
- },
- 'imagenet_resnet_gelu': {
- 'workload_path': 'imagenet_resnet/imagenet',
- 'workload_class_name': 'ImagenetResNetGELUWorkload',
- },
- 'imagenet_resnet_large_bn_init': {
- 'workload_path': 'imagenet_resnet/imagenet',
- 'workload_class_name': 'ImagenetResNetLargeBNScaleWorkload',
- },
- 'imagenet_vit': {
- 'workload_path': 'imagenet_vit/imagenet',
- 'workload_class_name': 'ImagenetVitWorkload',
- },
- 'imagenet_vit_glu': {
- 'workload_path': 'imagenet_vit/imagenet',
- 'workload_class_name': 'ImagenetVitGluWorkload',
- },
- 'imagenet_vit_post_ln': {
- 'workload_path': 'imagenet_vit/imagenet',
- 'workload_class_name': 'ImagenetVitPostLNWorkload',
- },
- 'imagenet_vit_map': {
- 'workload_path': 'imagenet_vit/imagenet',
- 'workload_class_name': 'ImagenetVitMapWorkload',
- },
- 'librispeech_conformer': {
- 'workload_path': 'librispeech_conformer/librispeech',
- 'workload_class_name': 'LibriSpeechConformerWorkload',
- },
- 'librispeech_conformer_attention_temperature': {
- 'workload_path':
- 'librispeech_conformer/librispeech',
- 'workload_class_name':
- 'LibriSpeechConformerAttentionTemperatureWorkload',
- },
- 'librispeech_conformer_layernorm': {
- 'workload_path': 'librispeech_conformer/librispeech',
- 'workload_class_name': 'LibriSpeechConformerLayerNormWorkload',
- },
- 'librispeech_conformer_gelu': {
- 'workload_path': 'librispeech_conformer/librispeech',
- 'workload_class_name': 'LibriSpeechConformerGeluWorkload',
- },
- 'librispeech_deepspeech': {
- 'workload_path': 'librispeech_deepspeech/librispeech',
- 'workload_class_name': 'LibriSpeechDeepSpeechWorkload',
- },
- 'librispeech_deepspeech_tanh': {
- 'workload_path': 'librispeech_deepspeech/librispeech',
- 'workload_class_name': 'LibriSpeechDeepSpeechTanhWorkload',
- },
- 'librispeech_deepspeech_no_resnet': {
- 'workload_path': 'librispeech_deepspeech/librispeech',
- 'workload_class_name': 'LibriSpeechDeepSpeechNoResNetWorkload',
- },
- 'librispeech_deepspeech_norm_and_spec_aug': {
- 'workload_path': 'librispeech_deepspeech/librispeech',
- 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload',
- },
- 'mnist': {
- 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload'
- },
- 'ogbg': {
- 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload'
- },
- 'ogbg_gelu': {
- 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgGeluWorkload'
- },
- 'ogbg_silu': {
- 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgSiluWorkload'
- },
- 'ogbg_model_size': {
- 'workload_path': 'ogbg/ogbg',
- 'workload_class_name': 'OgbgModelSizeWorkload'
- },
- 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'},
- 'wmt_post_ln': {
- 'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkloadPostLN'
- },
- 'wmt_attention_temp': {
- 'workload_path': 'wmt/wmt',
- 'workload_class_name': 'WmtWorkloadAttentionTemp'
- },
- 'wmt_glu_tanh': {
- 'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkloadGLUTanH'
- },
+ 'cifar': {
+ 'workload_path': 'cifar/cifar',
+ 'workload_class_name': 'CifarWorkload',
+ },
+ 'criteo1tb': {
+ 'workload_path': 'criteo1tb/criteo1tb',
+ 'workload_class_name': 'Criteo1TbDlrmSmallWorkload',
+ },
+ 'criteo1tb_test': {
+ 'workload_path': 'criteo1tb/criteo1tb',
+ 'workload_class_name': 'Criteo1TbDlrmSmallTestWorkload',
+ },
+ 'criteo1tb_layernorm': {
+ 'workload_path': 'criteo1tb/criteo1tb',
+ 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload',
+ },
+ 'criteo1tb_embed_init': {
+ 'workload_path': 'criteo1tb/criteo1tb',
+ 'workload_class_name': 'Criteo1TbDlrmSmallEmbedInitWorkload',
+ },
+ 'criteo1tb_resnet': {
+ 'workload_path': 'criteo1tb/criteo1tb',
+ 'workload_class_name': 'Criteo1TbDlrmSmallResNetWorkload',
+ },
+ 'fastmri': {
+ 'workload_path': 'fastmri/fastmri',
+ 'workload_class_name': 'FastMRIWorkload',
+ },
+ 'fastmri_model_size': {
+ 'workload_path': 'fastmri/fastmri',
+ 'workload_class_name': 'FastMRIModelSizeWorkload',
+ },
+ 'fastmri_tanh': {
+ 'workload_path': 'fastmri/fastmri',
+ 'workload_class_name': 'FastMRITanhWorkload',
+ },
+ 'fastmri_layernorm': {
+ 'workload_path': 'fastmri/fastmri',
+ 'workload_class_name': 'FastMRILayerNormWorkload',
+ },
+ 'imagenet_resnet': {
+ 'workload_path': 'imagenet_resnet/imagenet',
+ 'workload_class_name': 'ImagenetResNetWorkload',
+ },
+ 'imagenet_resnet_silu': {
+ 'workload_path': 'imagenet_resnet/imagenet',
+ 'workload_class_name': 'ImagenetResNetSiLUWorkload',
+ },
+ 'imagenet_resnet_gelu': {
+ 'workload_path': 'imagenet_resnet/imagenet',
+ 'workload_class_name': 'ImagenetResNetGELUWorkload',
+ },
+ 'imagenet_resnet_large_bn_init': {
+ 'workload_path': 'imagenet_resnet/imagenet',
+ 'workload_class_name': 'ImagenetResNetLargeBNScaleWorkload',
+ },
+ 'imagenet_vit': {
+ 'workload_path': 'imagenet_vit/imagenet',
+ 'workload_class_name': 'ImagenetVitWorkload',
+ },
+ 'imagenet_vit_glu': {
+ 'workload_path': 'imagenet_vit/imagenet',
+ 'workload_class_name': 'ImagenetVitGluWorkload',
+ },
+ 'imagenet_vit_post_ln': {
+ 'workload_path': 'imagenet_vit/imagenet',
+ 'workload_class_name': 'ImagenetVitPostLNWorkload',
+ },
+ 'imagenet_vit_map': {
+ 'workload_path': 'imagenet_vit/imagenet',
+ 'workload_class_name': 'ImagenetVitMapWorkload',
+ },
+ 'librispeech_conformer': {
+ 'workload_path': 'librispeech_conformer/librispeech',
+ 'workload_class_name': 'LibriSpeechConformerWorkload',
+ },
+ 'librispeech_conformer_attention_temperature': {
+ 'workload_path': 'librispeech_conformer/librispeech',
+ 'workload_class_name': 'LibriSpeechConformerAttentionTemperatureWorkload',
+ },
+ 'librispeech_conformer_layernorm': {
+ 'workload_path': 'librispeech_conformer/librispeech',
+ 'workload_class_name': 'LibriSpeechConformerLayerNormWorkload',
+ },
+ 'librispeech_conformer_gelu': {
+ 'workload_path': 'librispeech_conformer/librispeech',
+ 'workload_class_name': 'LibriSpeechConformerGeluWorkload',
+ },
+ 'librispeech_deepspeech': {
+ 'workload_path': 'librispeech_deepspeech/librispeech',
+ 'workload_class_name': 'LibriSpeechDeepSpeechWorkload',
+ },
+ 'librispeech_deepspeech_tanh': {
+ 'workload_path': 'librispeech_deepspeech/librispeech',
+ 'workload_class_name': 'LibriSpeechDeepSpeechTanhWorkload',
+ },
+ 'librispeech_deepspeech_no_resnet': {
+ 'workload_path': 'librispeech_deepspeech/librispeech',
+ 'workload_class_name': 'LibriSpeechDeepSpeechNoResNetWorkload',
+ },
+ 'librispeech_deepspeech_norm_and_spec_aug': {
+ 'workload_path': 'librispeech_deepspeech/librispeech',
+ 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload',
+ },
+ 'mnist': {
+ 'workload_path': 'mnist/mnist',
+ 'workload_class_name': 'MnistWorkload',
+ },
+ 'ogbg': {'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload'},
+ 'ogbg_gelu': {
+ 'workload_path': 'ogbg/ogbg',
+ 'workload_class_name': 'OgbgGeluWorkload',
+ },
+ 'ogbg_silu': {
+ 'workload_path': 'ogbg/ogbg',
+ 'workload_class_name': 'OgbgSiluWorkload',
+ },
+ 'ogbg_model_size': {
+ 'workload_path': 'ogbg/ogbg',
+ 'workload_class_name': 'OgbgModelSizeWorkload',
+ },
+ 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'},
+ 'wmt_post_ln': {
+ 'workload_path': 'wmt/wmt',
+ 'workload_class_name': 'WmtWorkloadPostLN',
+ },
+ 'wmt_attention_temp': {
+ 'workload_path': 'wmt/wmt',
+ 'workload_class_name': 'WmtWorkloadAttentionTemp',
+ },
+ 'wmt_glu_tanh': {
+ 'workload_path': 'wmt/wmt',
+ 'workload_class_name': 'WmtWorkloadGLUTanH',
+ },
}
BASE_WORKLOADS = [
- 'criteo1tb',
- 'fastmri',
- 'imagenet_resnet',
- 'imagenet_vit',
- 'librispeech_conformer',
- 'librispeech_deepspeech',
- 'ogbg',
- 'wmt'
+ 'criteo1tb',
+ 'fastmri',
+ 'imagenet_resnet',
+ 'imagenet_vit',
+ 'librispeech_conformer',
+ 'librispeech_deepspeech',
+ 'ogbg',
+ 'wmt',
]
@@ -171,10 +173,12 @@ def convert_filepath_to_module(path: str):
return base.replace('/', '.')
-def import_workload(workload_path: str,
- workload_class_name: str,
- return_class=False,
- workload_init_kwargs=None) -> spec.Workload:
+def import_workload(
+ workload_path: str,
+ workload_class_name: str,
+ return_class=False,
+ workload_init_kwargs=None,
+) -> spec.Workload:
"""Import and add the workload to the registry.
This importlib loading is nice to have because it allows runners to avoid
@@ -206,9 +210,10 @@ def import_workload(workload_path: str,
break
if workload_class is None:
raise ValueError(
- f'Could not find member {workload_class_name} in {workload_path}. '
- 'Make sure the Workload class is spelled correctly and defined in '
- 'the top scope of the module.')
+ f'Could not find member {workload_class_name} in {workload_path}. '
+ 'Make sure the Workload class is spelled correctly and defined in '
+ 'the top scope of the module.'
+ )
if return_class:
return workload_class
return workload_class(**workload_init_kwargs)
diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py
index efe923dbe..e110930cd 100644
--- a/datasets/dataset_setup.py
+++ b/datasets/dataset_setup.py
@@ -72,8 +72,7 @@
from torchvision.datasets import CIFAR10
from algoperf.workloads.wmt import tokenizer
-from algoperf.workloads.wmt.input_pipeline import \
- normalize_feature_names
+from algoperf.workloads.wmt.input_pipeline import normalize_feature_names
from datasets import librispeech_preprocess
from datasets import librispeech_tokenizer
@@ -101,84 +100,96 @@
FASTMRI_TEST_TAR_FILENAME = 'knee_singlecoil_test.tar.xz'
flags.DEFINE_boolean(
- 'interactive_deletion',
- True,
- 'If true, user will be prompted before any files are deleted. If false, no '
- 'files will be deleted.')
+ 'interactive_deletion',
+ True,
+ 'If true, user will be prompted before any files are deleted. If false, no '
+ 'files will be deleted.',
+)
flags.DEFINE_boolean(
- 'all',
- False,
- 'Whether or not to download all datasets. If false, can download some '
- 'combination of datasets by setting the individual dataset flags below.')
-
-flags.DEFINE_boolean('criteo1tb',
- False,
- 'If --all=false, whether or not to download Criteo 1TB.')
-flags.DEFINE_boolean('cifar',
- False,
- 'If --all=false, whether or not to download CIFAR-10.')
-flags.DEFINE_boolean('fastmri',
- False,
- 'If --all=false, whether or not to download FastMRI.')
-flags.DEFINE_boolean('imagenet',
- False,
- 'If --all=false, whether or not to download Imagenet.')
-flags.DEFINE_boolean('librispeech',
- False,
- 'If --all=false, whether or not to download LibriSpeech.')
-flags.DEFINE_boolean('mnist',
- False,
- 'If --all=false, whether or not to download MNIST.')
-flags.DEFINE_boolean('ogbg',
- False,
- 'If --all=false, whether or not to download OGBG.')
-flags.DEFINE_boolean('wmt',
- False,
- 'If --all=false, whether or not to download WMT.')
+ 'all',
+ False,
+ 'Whether or not to download all datasets. If false, can download some '
+ 'combination of datasets by setting the individual dataset flags below.',
+)
+
+flags.DEFINE_boolean(
+ 'criteo1tb', False, 'If --all=false, whether or not to download Criteo 1TB.'
+)
+flags.DEFINE_boolean(
+ 'cifar', False, 'If --all=false, whether or not to download CIFAR-10.'
+)
+flags.DEFINE_boolean(
+ 'fastmri', False, 'If --all=false, whether or not to download FastMRI.'
+)
+flags.DEFINE_boolean(
+ 'imagenet', False, 'If --all=false, whether or not to download Imagenet.'
+)
+flags.DEFINE_boolean(
+ 'librispeech',
+ False,
+ 'If --all=false, whether or not to download LibriSpeech.',
+)
+flags.DEFINE_boolean(
+ 'mnist', False, 'If --all=false, whether or not to download MNIST.'
+)
+flags.DEFINE_boolean(
+ 'ogbg', False, 'If --all=false, whether or not to download OGBG.'
+)
+flags.DEFINE_boolean(
+ 'wmt', False, 'If --all=false, whether or not to download WMT.'
+)
flags.DEFINE_string(
- 'data_dir',
- '~/data',
- 'The path to the folder where datasets should be downloaded.')
+ 'data_dir',
+ '~/data',
+ 'The path to the folder where datasets should be downloaded.',
+)
flags.DEFINE_string(
- 'temp_dir',
- '/tmp/mlcommons',
- 'A local path to a folder where temp files can be downloaded.')
+ 'temp_dir',
+ '/tmp/mlcommons',
+ 'A local path to a folder where temp files can be downloaded.',
+)
flags.DEFINE_string(
- 'imagenet_train_url',
- None,
- 'Only necessary if you want this script to `wget` the ImageNet train '
- 'split. If not, you can supply the path to --data_dir in '
- 'submission_runner.py.')
+ 'imagenet_train_url',
+ None,
+ 'Only necessary if you want this script to `wget` the ImageNet train '
+ 'split. If not, you can supply the path to --data_dir in '
+ 'submission_runner.py.',
+)
flags.DEFINE_string(
- 'imagenet_val_url',
- None,
- 'Only necessary if you want this script to `wget` the ImageNet validation '
- 'split. If not, you can supply the path to --data_dir in '
- 'submission_runner.py.')
+ 'imagenet_val_url',
+ None,
+ 'Only necessary if you want this script to `wget` the ImageNet validation '
+ 'split. If not, you can supply the path to --data_dir in '
+ 'submission_runner.py.',
+)
flags.DEFINE_string(
- 'fastmri_knee_singlecoil_train_url',
- None,
- 'Only necessary if you want this script to `wget` the FastMRI train '
- 'split. If not, you can supply the path to --data_dir in '
- 'submission_runner.py.')
+ 'fastmri_knee_singlecoil_train_url',
+ None,
+ 'Only necessary if you want this script to `wget` the FastMRI train '
+ 'split. If not, you can supply the path to --data_dir in '
+ 'submission_runner.py.',
+)
flags.DEFINE_string(
- 'fastmri_knee_singlecoil_val_url',
- None,
- 'Only necessary if you want this script to `wget` the FastMRI validation '
- 'split. If not, you can supply the path to --data_dir in '
- 'submission_runner.py.')
+ 'fastmri_knee_singlecoil_val_url',
+ None,
+ 'Only necessary if you want this script to `wget` the FastMRI validation '
+ 'split. If not, you can supply the path to --data_dir in '
+ 'submission_runner.py.',
+)
flags.DEFINE_string(
- 'fastmri_knee_singlecoil_test_url',
- None,
- 'Only necessary if you want this script to `wget` the FastMRI test '
- 'split. If not, you can supply the path to --data_dir in '
- 'submission_runner.py.')
+ 'fastmri_knee_singlecoil_test_url',
+ None,
+ 'Only necessary if you want this script to `wget` the FastMRI test '
+ 'split. If not, you can supply the path to --data_dir in '
+ 'submission_runner.py.',
+)
flags.DEFINE_integer(
- 'num_decompression_threads',
- 8,
- 'The number of threads to use in parallel when decompressing.')
+ 'num_decompression_threads',
+ 8,
+ 'The number of threads to use in parallel when decompressing.',
+)
flags.DEFINE_string('framework', None, 'Can be either jax or pytorch.')
@@ -186,7 +197,7 @@
FLAGS = flags.FLAGS
-os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
def _maybe_mkdir(d):
@@ -198,8 +209,10 @@ def _maybe_prompt_for_deletion(paths, interactive_deletion):
if not interactive_deletion:
return
files_for_deletion = '\n'.join(paths)
- logging.info('\n\n\nWARNING: the following temp files will be DELETED:'
- f'\n{files_for_deletion}')
+ logging.info(
+ '\n\n\nWARNING: the following temp files will be DELETED:'
+ f'\n{files_for_deletion}'
+ )
delete_str = input('Confirm deletion? [y/N]: ')
if delete_str.lower() == 'y':
del_cmd = 'rm ' + ' '.join(f'"{s}"' for s in paths)
@@ -225,8 +238,9 @@ def _download_url(url, data_dir, name=None):
if os.path.exists(file_path):
while True:
- overwrite = input('File already exists {}.\n Overwrite? (Y/n)'.format(
- file_path)).lower()
+ overwrite = input(
+ 'File already exists {}.\n Overwrite? (Y/n)'.format(file_path)
+ ).lower()
if overwrite in ['y', 'n']:
break
logging.info('Invalid response. Try again.')
@@ -240,17 +254,18 @@ def _download_url(url, data_dir, name=None):
progress_bar.update(chunk_size_in_mib)
f.write(chunk)
progress_bar.close()
- if (progress_bar.total != 0 and progress_bar.n != progress_bar.total):
+ if progress_bar.total != 0 and progress_bar.n != progress_bar.total:
raise RuntimeError(
- ('Download corrupted, size {n} MiB from {url} does not match '
- 'expected size {size} MiB').format(
- url=url, n=progress_bar.n, size=progress_bar.total))
+ (
+ 'Download corrupted, size {n} MiB from {url} does not match '
+ 'expected size {size} MiB'
+ ).format(url=url, n=progress_bar.n, size=progress_bar.total)
+ )
-def download_criteo1tb(data_dir,
- tmp_dir,
- num_decompression_threads,
- interactive_deletion):
+def download_criteo1tb(
+ data_dir, tmp_dir, num_decompression_threads, interactive_deletion
+):
criteo_dir = os.path.join(data_dir, 'criteo1tb')
tmp_criteo_dir = os.path.join(tmp_dir, 'criteo1tb')
_maybe_mkdir(criteo_dir)
@@ -258,47 +273,56 @@ def download_criteo1tb(data_dir,
# Forked from
# https://github.com/iamleot/transferwee/blob/master/transferwee.py.
- user_agent = ('Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:102.0) '
- 'Gecko/20100101 Firefox/102.0')
+ user_agent = (
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:102.0) '
+ 'Gecko/20100101 Firefox/102.0'
+ )
criteo_wetransfer_url = (
- 'https://criteo.wetransfer.com/downloads/'
- '4bbea9b4a54baddea549d71271a38e2c20230428071257/d4f0d2')
+ 'https://criteo.wetransfer.com/downloads/'
+ '4bbea9b4a54baddea549d71271a38e2c20230428071257/d4f0d2'
+ )
_, _, transfer_id, security_hash = urllib.parse.urlparse(
- criteo_wetransfer_url).path.split('/')
+ criteo_wetransfer_url
+ ).path.split('/')
session = requests.Session()
- session.headers.update({
+ session.headers.update(
+ {
'User-Agent': user_agent,
'x-requested-with': 'XMLHttpRequest',
- })
+ }
+ )
r = session.get('https://wetransfer.com/')
m = re.search('name="csrf-token" content="([^"]+)"', r.text)
if m:
session.headers.update({'x-csrf-token': m.group(1)})
get_url_request = session.post(
- f'https://wetransfer.com/api/v4/transfers/{transfer_id}/download',
- json={
- 'intent': 'entire_transfer',
- 'security_hash': security_hash,
- })
+ f'https://wetransfer.com/api/v4/transfers/{transfer_id}/download',
+ json={
+ 'intent': 'entire_transfer',
+ 'security_hash': security_hash,
+ },
+ )
session.close()
download_url = get_url_request.json().get('direct_link')
logging.info(f'Downloading ~342GB Criteo 1TB data .zip file:\n{download_url}')
download_request = requests.get( # pylint: disable=missing-timeout
- download_url,
- headers={'User-Agent': user_agent},
- stream=True)
+ download_url, headers={'User-Agent': user_agent}, stream=True
+ )
all_days_zip_filepath = os.path.join(tmp_criteo_dir, 'all_days.zip')
if not FLAGS.skip_download:
download = True
if os.path.exists(all_days_zip_filepath):
while True:
- overwrite = input('File already exists {}.\n Overwrite? (Y/n)'.format(
- all_days_zip_filepath)).lower()
+ overwrite = input(
+ 'File already exists {}.\n Overwrite? (Y/n)'.format(
+ all_days_zip_filepath
+ )
+ ).lower()
if overwrite in ['y', 'n']:
break
logging.info('Invalid response. Try again.')
@@ -324,8 +348,10 @@ def download_criteo1tb(data_dir,
input_path = os.path.join(tmp_criteo_dir, f'day_{day}.gz')
gz_paths.append(input_path)
unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv')
- unzip_cmd = (f'pigz -d -c -p{num_decompression_threads} "{input_path}" > '
- f'"{unzipped_path}"')
+ unzip_cmd = (
+ f'pigz -d -c -p{num_decompression_threads} "{input_path}" > '
+ f'"{unzipped_path}"'
+ )
logging.info(f'Running Criteo unzip command for day {day}:\n{unzip_cmd}')
processes.append(subprocess.Popen(unzip_cmd, shell=True))
for p in processes:
@@ -341,8 +367,7 @@ def download_criteo1tb(data_dir,
unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv')
unzipped_paths.append(unzipped_path)
split_path = os.path.join(criteo_dir, f'day_{day}_')
- split_cmd = ('split -a 2 -d -l 5000000 '
- f'"{unzipped_path}" "{split_path}"')
+ split_cmd = f'split -a 2 -d -l 5000000 "{unzipped_path}" "{split_path}"'
logging.info(f'Running Criteo 1TB split command:\n{split_cmd}')
batch_processes.append(subprocess.Popen(split_cmd, shell=True))
for p in batch_processes:
@@ -362,45 +387,50 @@ def download_cifar(data_dir, framework):
def extract_filename_from_url(url, start_str='knee', end_str='.xz'):
- """ The url filenames are sometimes couched within a urldefense+aws access id
+ """The url filenames are sometimes couched within a urldefense+aws access id
etc. string. Unfortunately querying the content disposition in requests fails
(not provided)... so fast search is done here within the url.
- """
+ """
failure = -1
start = url.find(start_str)
end = url.find(end_str)
if failure in (start, end):
raise ValueError(
- f'Unable to locate filename wrapped in {start_str}--{end_str} in {url}')
+ f'Unable to locate filename wrapped in {start_str}--{end_str} in {url}'
+ )
end += len(end_str) # make it inclusive
return url[start:end]
-def download_fastmri(data_dir,
- fastmri_train_url,
- fastmri_val_url,
- fastmri_test_url):
+def download_fastmri(
+ data_dir, fastmri_train_url, fastmri_val_url, fastmri_test_url
+):
data_dir = os.path.join(data_dir, 'fastmri')
# Download fastmri train dataset
knee_train_filename = extract_filename_from_url(fastmri_train_url)
logging.info(
- 'Downloading fastmri train dataset from {}'.format(fastmri_train_url))
+ 'Downloading fastmri train dataset from {}'.format(fastmri_train_url)
+ )
_download_url(
- url=fastmri_train_url, data_dir=data_dir, name=knee_train_filename)
+ url=fastmri_train_url, data_dir=data_dir, name=knee_train_filename
+ )
# Download fastmri val dataset
knee_val_filename = extract_filename_from_url(fastmri_val_url)
logging.info(
- 'Downloading fastmri val dataset from {}'.format(fastmri_val_url))
+ 'Downloading fastmri val dataset from {}'.format(fastmri_val_url)
+ )
_download_url(url=fastmri_val_url, data_dir=data_dir, name=knee_val_filename)
# Download fastmri test dataset
knee_test_filename = extract_filename_from_url(fastmri_test_url)
logging.info(
- 'Downloading fastmri test dataset from {}'.format(fastmri_test_url))
+ 'Downloading fastmri test dataset from {}'.format(fastmri_test_url)
+ )
_download_url(
- url=fastmri_test_url, data_dir=data_dir, name=knee_test_filename)
+ url=fastmri_test_url, data_dir=data_dir, name=knee_test_filename
+ )
return data_dir
@@ -432,18 +462,18 @@ def setup_fastmri(data_dir):
# Rename folders to match what the workload expects
os.rename(
- os.path.join(data_dir, "singlecoil_train"),
- os.path.join(data_dir, "knee_singlecoil_train"),
+ os.path.join(data_dir, 'singlecoil_train'),
+ os.path.join(data_dir, 'knee_singlecoil_train'),
)
os.rename(
- os.path.join(data_dir, "singlecoil_val"),
- os.path.join(data_dir, "knee_singlecoil_val"),
+ os.path.join(data_dir, 'singlecoil_val'),
+ os.path.join(data_dir, 'knee_singlecoil_val'),
)
os.rename(
- os.path.join(data_dir, "singlecoil_test"),
- os.path.join(data_dir, "knee_singlecoil_test"),
+ os.path.join(data_dir, 'singlecoil_test'),
+ os.path.join(data_dir, 'knee_singlecoil_test'),
)
- logging.info("Set up fastMRI dataset complete")
+ logging.info('Set up fastMRI dataset complete')
def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url):
@@ -456,26 +486,32 @@ def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url):
# been moved to the manual_download_dir.
# Get paths in manual_download_dir.
imagenet_jax_data_dir = os.path.join(data_dir, 'jax')
- manual_download_dir = os.path.join(imagenet_jax_data_dir,
- 'downloads',
- 'manual')
- imagenet_train_download_filepath = os.path.join(manual_download_dir,
- IMAGENET_TRAIN_TAR_FILENAME)
- imagenet_val_download_filepath = os.path.join(manual_download_dir,
- IMAGENET_VAL_TAR_FILENAME)
+ manual_download_dir = os.path.join(
+ imagenet_jax_data_dir, 'downloads', 'manual'
+ )
+ imagenet_train_download_filepath = os.path.join(
+ manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME
+ )
+ imagenet_val_download_filepath = os.path.join(
+ manual_download_dir, IMAGENET_VAL_TAR_FILENAME
+ )
# Download imagenet train dataset
if not os.path.exists(imagenet_train_filepath) and not os.path.exists(
- imagenet_train_download_filepath):
+ imagenet_train_download_filepath
+ ):
logging.info(
- 'Downloading imagenet train dataset from {}'.format(imagenet_train_url))
+ 'Downloading imagenet train dataset from {}'.format(imagenet_train_url)
+ )
_download_url(url=imagenet_train_url, data_dir=data_dir)
# Download imagenet val dataset
if not os.path.exists(imagenet_val_filepath) and not os.path.exists(
- imagenet_val_download_filepath):
- logging.info('Downloading imagenet validation dataset from {}'.format(
- imagenet_val_url))
+ imagenet_val_download_filepath
+ ):
+ logging.info(
+ 'Downloading imagenet validation dataset from {}'.format(imagenet_val_url)
+ )
_download_url(url=imagenet_val_url, data_dir=data_dir)
# Download imagenet test set
@@ -501,31 +537,40 @@ def setup_imagenet_jax(data_dir):
# Setup jax dataset dir
imagenet_jax_data_dir = os.path.join(data_dir, 'jax')
- manual_download_dir = os.path.join(imagenet_jax_data_dir,
- 'downloads',
- 'manual')
+ manual_download_dir = os.path.join(
+ imagenet_jax_data_dir, 'downloads', 'manual'
+ )
os.makedirs(manual_download_dir, exist_ok=True)
# Copy tar file into jax/downloads/manual
logging.info('Checking if tar files already exists in jax/downloads/manual.')
if not os.path.exists(
- os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)):
- logging.info('Moving {} to {}'.format(train_tar_file_path,
- manual_download_dir))
+ os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)
+ ):
+ logging.info(
+ 'Moving {} to {}'.format(train_tar_file_path, manual_download_dir)
+ )
shutil.move(train_tar_file_path, manual_download_dir)
if not os.path.exists(
- os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)):
- logging.info('Moving {} to {}'.format(val_tar_file_path,
- manual_download_dir))
+ os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)
+ ):
+ logging.info(
+ 'Moving {} to {}'.format(val_tar_file_path, manual_download_dir)
+ )
shutil.move(val_tar_file_path, manual_download_dir)
if not os.path.exists(os.path.join(imagenet_jax_data_dir, 'imagenet_v2')):
- logging.info('Moving imagenet_v2 to {}'.format(
- os.path.join(imagenet_jax_data_dir, 'imagenet_v2')))
- shutil.move(test_dir_path,
- os.path.join(imagenet_jax_data_dir, 'imagenet_v2'))
+ logging.info(
+ 'Moving imagenet_v2 to {}'.format(
+ os.path.join(imagenet_jax_data_dir, 'imagenet_v2')
+ )
+ )
+ shutil.move(
+ test_dir_path, os.path.join(imagenet_jax_data_dir, 'imagenet_v2')
+ )
logging.info('Preparing imagenet data.')
ds_builder = tfds.builder(
- 'imagenet2012:5.1.0', data_dir=os.path.join(imagenet_jax_data_dir))
+ 'imagenet2012:5.1.0', data_dir=os.path.join(imagenet_jax_data_dir)
+ )
ds_builder.download_and_prepare()
logging.info('Set up imagenet dataset for jax framework complete')
@@ -539,14 +584,18 @@ def setup_imagenet_pytorch(data_dir):
manual_download_dir = os.path.join(data_dir, 'jax', 'downloads', 'manual')
if not os.path.exists(train_tar_file_path):
if os.path.exists(
- os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)):
- train_tar_file_path = os.path.join(manual_download_dir,
- IMAGENET_TRAIN_TAR_FILENAME)
+ os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)
+ ):
+ train_tar_file_path = os.path.join(
+ manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME
+ )
if not os.path.exists(val_tar_file_path):
if os.path.exists(
- os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)):
- val_tar_file_path = os.path.join(manual_download_dir,
- IMAGENET_VAL_TAR_FILENAME)
+ os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)
+ ):
+ val_tar_file_path = os.path.join(
+ manual_download_dir, IMAGENET_VAL_TAR_FILENAME
+ )
# Setup pytorch dataset dir
imagenet_pytorch_data_dir = os.path.join(data_dir, 'pytorch')
@@ -557,56 +606,68 @@ def setup_imagenet_pytorch(data_dir):
# Move tar files and imagenet_v2 into pytorch directory
if not os.path.exists(
- os.path.join(imagenet_pytorch_data_dir, IMAGENET_TRAIN_TAR_FILENAME)):
- logging.info('Moving {} to {}'.format(train_tar_file_path,
- imagenet_pytorch_data_dir))
+ os.path.join(imagenet_pytorch_data_dir, IMAGENET_TRAIN_TAR_FILENAME)
+ ):
+ logging.info(
+ 'Moving {} to {}'.format(train_tar_file_path, imagenet_pytorch_data_dir)
+ )
shutil.move(train_tar_file_path, imagenet_pytorch_data_dir)
if not os.path.exists(
- os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME)):
- logging.info('Moving {} to {}'.format(val_tar_file_path,
- imagenet_pytorch_data_dir))
+ os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME)
+ ):
+ logging.info(
+ 'Moving {} to {}'.format(val_tar_file_path, imagenet_pytorch_data_dir)
+ )
shutil.move(val_tar_file_path, imagenet_pytorch_data_dir)
if not os.path.exists(os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')):
- logging.info('Moving imagenet_v2 to {}'.format(
- os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')))
- shutil.move(test_dir_path,
- os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2'))
+ logging.info(
+ 'Moving imagenet_v2 to {}'.format(
+ os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')
+ )
+ )
+ shutil.move(
+ test_dir_path, os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')
+ )
  # Extract train data
logging.info('Extracting imagenet train data')
extract(
- os.path.join(imagenet_pytorch_data_dir, IMAGENET_TRAIN_TAR_FILENAME),
- os.path.join(imagenet_pytorch_data_dir, 'train'),
- mode='r:')
+ os.path.join(imagenet_pytorch_data_dir, IMAGENET_TRAIN_TAR_FILENAME),
+ os.path.join(imagenet_pytorch_data_dir, 'train'),
+ mode='r:',
+ )
train_tar_filenames = os.listdir(
- os.path.join(imagenet_pytorch_data_dir, 'train'))
+ os.path.join(imagenet_pytorch_data_dir, 'train')
+ )
for tar_filename in train_tar_filenames:
if tar_filename.endswith('.tar'):
dir_name = tar_filename[:-4]
extract(
- os.path.join(imagenet_pytorch_data_dir, 'train', tar_filename),
- os.path.join(imagenet_pytorch_data_dir, 'train', dir_name),
- mode='r:')
+ os.path.join(imagenet_pytorch_data_dir, 'train', tar_filename),
+ os.path.join(imagenet_pytorch_data_dir, 'train', dir_name),
+ mode='r:',
+ )
# Extract val data
logging.info('Extracting imagenet val data')
extract(
- os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME),
- os.path.join(imagenet_pytorch_data_dir, 'val'),
- mode='r:')
+ os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME),
+ os.path.join(imagenet_pytorch_data_dir, 'val'),
+ mode='r:',
+ )
valprep_command = [
- 'wget',
- '-qO-',
- 'https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh'
+ 'wget',
+ '-qO-',
+ 'https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh',
]
valprep_download = subprocess.Popen(valprep_command, stdout=subprocess.PIPE)
- valprep_process = subprocess.Popen(['bash'],
- stdin=valprep_download.stdout,
- cwd=os.path.expanduser(
- os.path.join(imagenet_pytorch_data_dir,
- 'val')))
+ valprep_process = subprocess.Popen(
+ ['bash'],
+ stdin=valprep_download.stdout,
+ cwd=os.path.expanduser(os.path.join(imagenet_pytorch_data_dir, 'val')),
+ )
valprep_download.stdout.close()
valprep_process.communicate()
logging.info('Set up imagenet dataset for pytorch framework complete')
@@ -614,8 +675,8 @@ def setup_imagenet_pytorch(data_dir):
def download_imagenet_v2(data_dir):
tfds.builder(
- 'imagenet_v2/matched-frequency:3.0.0',
- data_dir=data_dir).download_and_prepare()
+ 'imagenet_v2/matched-frequency:3.0.0', data_dir=data_dir
+ ).download_and_prepare()
def download_librispeech(data_dir, tmp_dir):
@@ -634,41 +695,46 @@ def download_librispeech(data_dir, tmp_dir):
if split == 'test' and version == 'other':
continue
wget_cmd = (
- f'wget --directory-prefix={tmp_librispeech_dir} '
- f'http://www.openslr.org/resources/12/{split}-{version}.tar.gz')
+ f'wget --directory-prefix={tmp_librispeech_dir} '
+ f'http://www.openslr.org/resources/12/{split}-{version}.tar.gz'
+ )
subprocess.Popen(wget_cmd, shell=True).communicate()
tar_path = os.path.join(tmp_librispeech_dir, f'{split}-{version}.tar.gz')
subprocess.Popen(
- f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}',
- shell=True).communicate()
+ f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', shell=True
+ ).communicate()
tars = [
- 'raw-metadata.tar.gz',
- 'train-clean-100.tar.gz',
- 'train-clean-360.tar.gz',
- 'train-other-500.tar.gz',
+ 'raw-metadata.tar.gz',
+ 'train-clean-100.tar.gz',
+ 'train-clean-360.tar.gz',
+ 'train-other-500.tar.gz',
]
for tar_filename in tars:
- wget_cmd = (f'wget --directory-prefix={tmp_librispeech_dir} '
- f'http://www.openslr.org/resources/12/{tar_filename}')
+ wget_cmd = (
+ f'wget --directory-prefix={tmp_librispeech_dir} '
+ f'http://www.openslr.org/resources/12/{tar_filename}'
+ )
subprocess.Popen(wget_cmd, shell=True).communicate()
tar_path = os.path.join(tmp_librispeech_dir, tar_filename)
subprocess.Popen(
- f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}',
- shell=True).communicate()
+ f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', shell=True
+ ).communicate()
tokenizer_vocab_path = os.path.join(final_data_dir, 'spm_model.vocab')
if not os.path.exists(tokenizer_vocab_path):
librispeech_tokenizer.run(
- train=True,
- input_dir=extracted_data_dir,
- tokenizer_vocab_path=tokenizer_vocab_path)
+ train=True,
+ input_dir=extracted_data_dir,
+ tokenizer_vocab_path=tokenizer_vocab_path,
+ )
librispeech_preprocess.run(
- input_dir=extracted_data_dir,
- output_dir=final_data_dir,
- tokenizer_vocab_path=tokenizer_vocab_path)
+ input_dir=extracted_data_dir,
+ output_dir=final_data_dir,
+ tokenizer_vocab_path=tokenizer_vocab_path,
+ )
def download_mnist(data_dir):
@@ -691,12 +757,14 @@ def download_wmt(data_dir):
if ds_name == 'wmt17_translate/de-en:1.0.0':
ds = dataset_builder.as_dataset(split='train', shuffle_files=False)
ds = ds.map(
- functools.partial(normalize_feature_names, dataset_builder.info),
- num_parallel_calls=tf.data.AUTOTUNE)
+ functools.partial(normalize_feature_names, dataset_builder.info),
+ num_parallel_calls=tf.data.AUTOTUNE,
+ )
# Tokenize data.
vocab_path = os.path.join(data_dir, 'wmt_sentencepiece_model')
tokenizer.train_tokenizer(
- ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7)
+ ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7
+ )
def main(_):
@@ -715,10 +783,9 @@ def main(_):
if FLAGS.all or FLAGS.criteo1tb:
logging.info('Downloading criteo1tb...')
- download_criteo1tb(data_dir,
- tmp_dir,
- num_decompression_threads,
- FLAGS.interactive_deletion)
+ download_criteo1tb(
+ data_dir, tmp_dir, num_decompression_threads, FLAGS.interactive_deletion
+ )
if FLAGS.all or FLAGS.mnist:
logging.info('Downloading MNIST...')
@@ -730,19 +797,24 @@ def main(_):
knee_singlecoil_train_url = FLAGS.fastmri_knee_singlecoil_train_url
knee_singlecoil_val_url = FLAGS.fastmri_knee_singlecoil_val_url
knee_singlecoil_test_url = FLAGS.fastmri_knee_singlecoil_test_url
- if None in (knee_singlecoil_train_url,
- knee_singlecoil_val_url,
- knee_singlecoil_test_url):
+ if None in (
+ knee_singlecoil_train_url,
+ knee_singlecoil_val_url,
+ knee_singlecoil_test_url,
+ ):
raise ValueError(
- 'Must provide three --fastmri_knee_singlecoil_[train,val,test]_url '
- 'to download the FastMRI dataset.\nSign up for the URLs at '
- 'https://fastmri.med.nyu.edu/.')
+ 'Must provide three --fastmri_knee_singlecoil_[train,val,test]_url '
+ 'to download the FastMRI dataset.\nSign up for the URLs at '
+ 'https://fastmri.med.nyu.edu/.'
+ )
if not FLAGS.skip_download:
- download_fastmri(data_dir,
- knee_singlecoil_train_url,
- knee_singlecoil_val_url,
- knee_singlecoil_test_url)
+ download_fastmri(
+ data_dir,
+ knee_singlecoil_train_url,
+ knee_singlecoil_val_url,
+ knee_singlecoil_test_url,
+ )
logging.info('fastMRI download completed. Extracting...')
setup_fastmri(data_dir)
@@ -754,12 +826,13 @@ def main(_):
imagenet_val_url = FLAGS.imagenet_val_url
if imagenet_train_url is None or imagenet_val_url is None:
raise ValueError(
- 'Must provide both --imagenet_{train,val}_url to download the '
- 'ImageNet dataset. Sign up for the URLs at https://image-net.org/.')
+ 'Must provide both --imagenet_{train,val}_url to download the '
+ 'ImageNet dataset. Sign up for the URLs at https://image-net.org/.'
+ )
if FLAGS.framework is None:
raise ValueError(
- 'Please specify either jax or pytorch framework through framework '
- 'flag.')
+ 'Please specify either jax or pytorch framework through framework flag.'
+ )
if not FLAGS.skip_download:
logging.info('Downloading ImageNet...')
download_imagenet(data_dir, imagenet_train_url, imagenet_val_url)
diff --git a/datasets/librispeech_preprocess.py b/datasets/librispeech_preprocess.py
index a8c5cae1d..1c216db46 100644
--- a/datasets/librispeech_preprocess.py
+++ b/datasets/librispeech_preprocess.py
@@ -4,16 +4,15 @@
import multiprocessing.dummy
import os
-from os.path import exists
import sys
import threading
import time
-from absl import logging
import numpy as np
import pandas as pd
-from pydub import AudioSegment
import tensorflow as tf
+from absl import logging
+from pydub import AudioSegment
from datasets import librispeech_tokenizer
@@ -28,17 +27,18 @@
# taken from TFDS page for librispeech dataset :
# https://www.tensorflow.org/datasets/catalog/librispeech
librispeech_example_counts = {
- 'train-clean-100': 28539,
- 'train-clean-360': 104014,
- 'train-other-500': 148688,
- 'test-clean': 2620, # 'test-other': 2939,
- 'dev-clean': 2703,
- 'dev-other': 2864,
+ 'train-clean-100': 28539,
+ 'train-clean-360': 104014,
+ 'train-other-500': 148688,
+ 'test-clean': 2620, # 'test-other': 2939,
+ 'dev-clean': 2703,
+ 'dev-other': 2864,
}
class Counter:
"""A threadsafe counter."""
+
lock = threading.Lock()
value = 0
@@ -56,10 +56,12 @@ def report_progress(count, total, start_time):
now = time.time()
size = 50
filled = int(round(size * count / float(total)))
- percent = round(100. * count / float(total), 1)
- bar = "-" * filled + "." * (size - filled)
- sys.stdout.write("[%s] %d%% (%d of %d) %.2f sample/sec\r" %
- (bar, percent, count, total, count / (now - start_time)))
+ percent = round(100.0 * count / float(total), 1)
+ bar = '-' * filled + '.' * (size - filled)
+ sys.stdout.write(
+ '[%s] %d%% (%d of %d) %.2f sample/sec\r'
+ % (bar, percent, count, total, count / (now - start_time))
+ )
sys.stdout.flush()
@@ -72,17 +74,20 @@ def process(index):
data_folder, speaker_folder, chapter_folder = index
utterance_ids = []
- trans_file = (f'{data_folder}/{speaker_folder}/{chapter_folder}/'
- f'{speaker_folder}-{chapter_folder}.trans.txt')
+ trans_file = (
+ f'{data_folder}/{speaker_folder}/{chapter_folder}/'
+ f'{speaker_folder}-{chapter_folder}.trans.txt'
+ )
if not exists(trans_file):
skipped.inc()
return utterance_ids
with open(trans_file, 'r', encoding='UTF-8') as f:
- for l in f:
- utt, trans = l.strip().split(' ', maxsplit=1)
+ for line in f:
+ utt, trans = line.strip().split(' ', maxsplit=1)
audio_path = (
- f'{data_folder}/{speaker_folder}/{chapter_folder}/{utt}.flac')
+ f'{data_folder}/{speaker_folder}/{chapter_folder}/{utt}.flac'
+ )
if not os.path.isfile(audio_path):
skipped.inc()
@@ -105,9 +110,11 @@ def process(index):
np.save('{}/{}/{}_targets.npy'.format(out_folder, split, utt), targets)
finished.inc()
- report_progress(finished.val() + skipped.val(),
- librispeech_example_counts[split],
- start_time)
+ report_progress(
+ finished.val() + skipped.val(),
+ librispeech_example_counts[split],
+ start_time,
+ )
utterance_ids.append(utt)
return utterance_ids
@@ -126,10 +133,12 @@ def process(index):
end_time = time.time()
elapsed_time = end_time - start_time
- print(' \n time taken to preprocess split : ',
- split,
- ' = ',
- time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))
+ print(
+ ' \n time taken to preprocess split : ',
+ split,
+ ' = ',
+ time.strftime('%H:%M:%S', time.gmtime(elapsed_time)),
+ )
final_count = finished.val() + skipped.val()
return pd.DataFrame(file_trans, columns=['id']), final_count
@@ -147,12 +156,12 @@ def run(input_dir, output_dir, tokenizer_vocab_path):
os.makedirs(output_dir, exist_ok=True)
subset_list = [
- 'train-clean-100',
- 'train-clean-360',
- 'train-other-500',
- 'dev-clean',
- 'dev-other',
- 'test-clean', # 'test-other',
+ 'train-clean-100',
+ 'train-clean-360',
+ 'train-other-500',
+ 'dev-clean',
+ 'dev-other',
+ 'test-clean', # 'test-other',
]
for subset in subset_list:
logging.info('Processing split = %s...', subset)
@@ -160,10 +169,14 @@ def run(input_dir, output_dir, tokenizer_vocab_path):
out_dir = os.path.join(output_dir, subset)
os.makedirs(out_dir, exist_ok=True)
example_ids, num_entries = preprocess_data(
- in_dir, output_dir, tokenizer, subset)
+ in_dir, output_dir, tokenizer, subset
+ )
if num_entries != librispeech_example_counts[subset]:
- raise ValueError('Preprocessed dataframe final count not equal to '
- 'expected count: {} vs expected {}'.format(
- num_entries, librispeech_example_counts[subset]))
+ raise ValueError(
+ 'Preprocessed dataframe final count not equal to '
+ 'expected count: {} vs expected {}'.format(
+ num_entries, librispeech_example_counts[subset]
+ )
+ )
example_ids.to_csv(os.path.join(output_dir, f'{subset}.csv'))
diff --git a/datasets/librispeech_tokenizer.py b/datasets/librispeech_tokenizer.py
index 2f559752a..d566d5716 100644
--- a/datasets/librispeech_tokenizer.py
+++ b/datasets/librispeech_tokenizer.py
@@ -8,10 +8,10 @@
import tempfile
from typing import Dict
-from absl import logging
import sentencepiece as spm
import tensorflow as tf
import tensorflow_text as tftxt
+from absl import logging
gfile = tf.io.gfile
copy = tf.io.gfile.copy
@@ -24,7 +24,8 @@
def dump_chars_for_training(data_folder, splits, maxchars: int = int(1e7)):
char_count = 0
with tempfile.NamedTemporaryFile(
- delete=False, prefix='/tmp/ds_chars') as outfp:
+ delete=False, prefix='/tmp/ds_chars'
+ ) as outfp:
for split in splits:
data_folder = data_folder + '/' + split
for _, speaker_folder in enumerate(os.listdir(data_folder)):
@@ -32,14 +33,16 @@ def dump_chars_for_training(data_folder, splits, maxchars: int = int(1e7)):
break
for chapter_folder in os.listdir(f'{data_folder}/{speaker_folder}'):
- trans_file = (f'{data_folder}/{speaker_folder}/{chapter_folder}/'
- f'{speaker_folder}-{chapter_folder}.trans.txt')
+ trans_file = (
+ f'{data_folder}/{speaker_folder}/{chapter_folder}/'
+ f'{speaker_folder}-{chapter_folder}.trans.txt'
+ )
if not exists(trans_file):
logging.info('path does not exist -> %s', trans_file)
continue
with open(trans_file, 'r', encoding='UTF-8') as f:
- for l in f:
- _, line = l.strip().split(' ', maxsplit=1)
+ for lines in f:
+ _, line = lines.strip().split(' ', maxsplit=1)
line = line + '\n'
char_count += len(line)
if char_count > maxchars:
@@ -50,13 +53,15 @@ def dump_chars_for_training(data_folder, splits, maxchars: int = int(1e7)):
return outfp
-def train_tokenizer(data_dir: str,
- splits,
- vocab_size: int = 1024,
- model_path: str = 'spm_model.vocab',
- maxchars: int = int(1e7),
- model_type: str = 'unigram',
- character_coverage: float = 1.0):
+def train_tokenizer(
+ data_dir: str,
+ splits,
+ vocab_size: int = 1024,
+ model_path: str = 'spm_model.vocab',
+ maxchars: int = int(1e7),
+ model_type: str = 'unigram',
+ character_coverage: float = 1.0,
+):
"""Train SentencePiece tokenizer from subset of tf dataset.
Args:
@@ -77,15 +82,18 @@ def train_tokenizer(data_dir: str,
charfile = dump_chars_for_training(data_dir, splits, maxchars=maxchars)
with tempfile.NamedTemporaryFile(
- delete=False, prefix='/tmp/sp_tmp') as model_fp:
+ delete=False, prefix='/tmp/sp_tmp'
+ ) as model_fp:
pass # we just want a prefix'd tmp-filename
- argstr = ' '.join([
+ argstr = ' '.join(
+ [
f'--input={charfile.name}',
f'--vocab_size={vocab_size}',
f'--character_coverage={character_coverage}',
f'--model_prefix={model_fp.name}',
f'--model_type={model_type}',
- ])
+ ]
+ )
spm.SentencePieceTrainer.Train(argstr)
copy_rename_path = abs_model_path + '.rntmp'
@@ -104,7 +112,8 @@ def load_tokenizer(model_filepath):
with gfile.GFile(model_filepath, 'rb') as model_fp:
sp_model = model_fp.read()
sp_tokenizer = tftxt.SentencepieceTokenizer(
- model=sp_model, add_bos=False, add_eos=True, reverse=False)
+ model=sp_model, add_bos=False, add_eos=True, reverse=False
+ )
return sp_tokenizer
@@ -123,8 +132,9 @@ def run(train, input_dir, tokenizer_vocab_path):
detokenized = tokenizer.detokenize(tokens).numpy().decode('utf-8')
logging.info('Original input = %s', test_input)
- logging.info('Output after after tokenizing and detokenizing = %s',
- detokenized)
+ logging.info(
+      'Output after tokenizing and detokenizing = %s', detokenized
+ )
if detokenized == test_input:
logging.info('Tokenizer working correctly!')
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 76bc5cfe0..9926b0542 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -23,8 +23,8 @@ RUN apt-get update && apt-get install -y \
libreadline-dev \
libffi-dev \
curl \
- libbz2-dev \
liblzma-dev \
+ libbz2-dev \
vim
# Download and install Python 3.11
@@ -56,8 +56,6 @@ RUN echo "Setting up directories for data and experiment_runs"
RUN mkdir -p data/
RUN mkdir -p experiment_runs/
-RUN pip install --upgrade pip
-
# Install Algorithmic efficiency repo
RUN pip install --upgrade pip
@@ -71,18 +69,18 @@ RUN cd /algorithmic-efficiency && git checkout $branch
RUN if [ "$framework" = "jax" ] ; then \
echo "Installing Jax GPU" \
&& cd /algorithmic-efficiency \
- && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \
- && pip install -e '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \
+ && pip install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu \
+ && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \
elif [ "$framework" = "pytorch" ] ; then \
echo "Installing Pytorch GPU" \
&& cd /algorithmic-efficiency \
- && pip install -e '.[jax_cpu]' \
- && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \
+ && pip install -e '.[pytorch_gpu]' \
+ && pip install -e '.[jax_cpu]'; \
elif [ "$framework" = "both" ] ; then \
echo "Installing Jax GPU and Pytorch GPU" \
&& cd /algorithmic-efficiency \
- && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \
- && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \
+ && pip install -e '.[pytorch_gpu]' \
+ && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \
else \
echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \
&& exit 1 ; \
diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh
index 645b81955..6b5e67ceb 100644
--- a/docker/build_docker_images.sh
+++ b/docker/build_docker_images.sh
@@ -1,27 +1,40 @@
#!/bin/bash
# Bash script to build and push dev docker images to artifact repo
# Usage:
-# bash build_docker_images.sh -b
+# bash build_docker_images.sh -b -f
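+# Example (illustrative values): build only the PyTorch image from the dev branch:
+#   bash build_docker_images.sh -b dev -f pytorch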
# Make program exit with non-zero exit code if any command fails.
set -e
-while getopts b: flag
+while getopts "b:p:f:" flag;
do
case "${flag}" in
b) GIT_BRANCH=${OPTARG};;
+ p) PROJECT=${OPTARG};;
+ f) FRAMEWORK=${OPTARG};;
esac
done
# Artifact repository
-ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo"
+if [ "$PROJECT" = "mlcommons-algoperf" ]; then
+ ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo"
+else
+ ARTIFACT_REPO="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo"
+fi
-if [[ -z ${GIT_BRANCH+x} ]]
+if [[ -z ${GIT_BRANCH+x} ]];
then
GIT_BRANCH='main' # Set default argument
fi
-for FRAMEWORK in "jax" "pytorch" "both"
+FRAMEWORKS=( "jax" "pytorch" "both" )
+
+if [[ -n "$FRAMEWORK" ]];
+then
+ FRAMEWORKS=("$FRAMEWORK")
+fi
+
+for FRAMEWORK in "${FRAMEWORKS[@]}";
do
IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}"
DOCKER_BUILD_COMMAND="docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH"
diff --git a/docker/scripts/singularity_converter.py b/docker/scripts/singularity_converter.py
index 48c521009..a816eb5c2 100644
--- a/docker/scripts/singularity_converter.py
+++ b/docker/scripts/singularity_converter.py
@@ -15,26 +15,27 @@
from spython.main.parse.writers import get_writer
# globals
-ENTRY_POINT = "/bin/bash" # seems to be a good default
+ENTRY_POINT = '/bin/bash' # seems to be a good default
FORCE = False # seems to be a good default
#
-parser = argparse.ArgumentParser(description="Custom Singularity converter")
+parser = argparse.ArgumentParser(description='Custom Singularity converter')
parser.add_argument(
- "-i", "--input", type=str, help="Docker input path", default="Dockerfile")
+ '-i', '--input', type=str, help='Docker input path', default='Dockerfile'
+)
parser.add_argument(
- "-o",
- "--output",
- type=str,
- help="Singularity output path",
- default="Singularity.def",
+ '-o',
+ '--output',
+ type=str,
+ help='Singularity output path',
+ default='Singularity.def',
)
args = parser.parse_args()
INPUT_DOCKERFILE_PATH = args.input
OUTPUT_SINGULARITY_PATH = args.output
# create Docker parser and Singularity writer
-parser = get_parser("docker")
-writer = get_writer("singularity")
+parser = get_parser('docker')
+writer = get_writer('singularity')
# parse Dockerfile into Singularity and suppress %files commands
recipeParser = parser(INPUT_DOCKERFILE_PATH)
@@ -44,5 +45,5 @@
# convert to string and save to output file
result = recipeWriter.convert(runscript=ENTRY_POINT, force=FORCE)
-with open(OUTPUT_SINGULARITY_PATH, "w") as f:
+with open(OUTPUT_SINGULARITY_PATH, 'w') as f:
f.write(result)
diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md
index 7778030dc..e9918e14c 100644
--- a/docs/CONTRIBUTING.md
+++ b/docs/CONTRIBUTING.md
@@ -179,13 +179,13 @@ docker run -t -d \
To find the container IDs of running containers
```bash
-docker ps
+docker ps
```
To see the logging output
```bash
-docker logs
+docker logs
```
To enter a bash session in the container
@@ -209,7 +209,7 @@ docker run -t -d \
--gpus all \
--ipc=host \
\
---keep_container_alive true
+--keep_container_alive true
```
## Submitting PRs
@@ -222,38 +222,26 @@ We run tests with GitHub Actions, configured in the [.github/workflows](.github/
### Style Testing
-We run yapf and linting tests on PRs. You can view and fix offending errors with these instructions.
-
+We run formatting and linting checks with Ruff on PRs. You can view and fix offending errors using the instructions below.
To run the below commands, use the versions installed via `pip install -e '.[dev]'`.
-To automatically fix formatting errors, run the following (*WARNING:* this will edit your code, so it is suggested to make a git commit first!):
+To check whether your code is **formatted** correctly, run the following:
```bash
-yapf -i -r -vv -p algoperf datasets prize_qualification_baselines reference_algorithms tests *.py
+ruff format --check
```
-To sort all import orderings, run the following:
+To automatically fix formatting errors, run `ruff format` (i.e., the same command without the `--check` flag).
+(**WARNING**: this will edit your code, so it is suggested to make a git commit first!)
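+For example, assuming you are in the repository root, you could run:
+```bash
+# Reformat everything Ruff is configured to cover:
+ruff format
+# Or limit it to specific paths, e.g.:
+ruff format algoperf submission_runner.py
+```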
-```bash
-isort .
-```
-
-To just print out all offending import orderings, run the following:
+To check whether your code is **linted** correctly, run the following:
```bash
-isort . --check --diff
+ruff check
```
-To print out all offending pylint issues, run the following:
-
-```bash
-pylint algoperf
-pylint datasets
-pylint prize_qualification_baselines
-pylint reference_algorithms
-pylint submission_runner.py
-pylint tests
-```
+To automatically fix linting errors, run `ruff check --fix` (i.e., the same command with the additional `--fix` flag).
+(**WARNING**: this will edit your code, so it is suggested to make a git commit first!)
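+For example, to apply the automatic lint fixes from the repository root:
+```bash
+# Fix lint errors everywhere Ruff is configured to check:
+ruff check --fix
+# Or limit it to specific paths, e.g.:
+ruff check algoperf tests --fix
+```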
### Unit and Integration Tests
@@ -270,9 +258,9 @@ To run a regression test:
1. Build and upload latest Docker images from dev branch.
- ```bash
- bash ~/algorithmic-efficiency/docker/build_docker_images.sh -b dev
- ```
+ ```bash
+ bash ~/algorithmic-efficiency/docker/build_docker_images.sh -b dev
+ ```
2. Turn on the self-hosted runner.
3. Run the self-hosted runner application for the runner to accept jobs.
diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
index c451a18ac..62161b3d5 100644
--- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
+++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
@@ -1,26 +1,24 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax."""
import functools
-
-# isort: off
-# We have to turn off isort here to resolve a conflict between isort and yapf.
-from typing import (Any,
- Callable,
- Dict,
- Iterator,
- List,
- NamedTuple,
- Optional,
- Tuple,
- Union)
-# isort: on
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ NamedTuple,
+ Optional,
+ Tuple,
+ Union,
+)
import chex
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
@@ -30,15 +28,14 @@
# Forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py
def nadamw(
- learning_rate: Union[float, optax.Schedule],
- b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- weight_decay: float = 0.0,
- weight_decay_mask: Optional[Union[Any, Callable[[optax.Params],
- Any]]] = None,
+ learning_rate: Union[float, optax.Schedule],
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ weight_decay: float = 0.0,
+ weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None,
) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
@@ -73,19 +70,22 @@ def nadamw(
An (init_fn, update_fn) tuple.
"""
return optax.chain(
- scale_by_nadam(b1, b2, eps, eps_root, debias),
- optax.add_decayed_weights(weight_decay, weight_decay_mask),
- scale_by_learning_rate(learning_rate))
+ scale_by_nadam(b1, b2, eps, eps_root, debias),
+ optax.add_decayed_weights(weight_decay, weight_decay_mask),
+ scale_by_learning_rate(learning_rate),
+ )
# All functions below are forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py
-def scale_by_nadam(b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- power: float = 0.5) -> optax.GradientTransformation:
+def scale_by_nadam(
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ power: float = 0.5,
+) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
References:
@@ -124,7 +124,8 @@ def update_fn(updates, state, params=None):
mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
nu_hat = nu if not debias else _bias_correction(nu, b2, count)
updates = jax.tree.map(
- lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
+ lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat
+ )
return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
return optax.GradientTransformation(init_fn, update_fn)
@@ -132,6 +133,7 @@ def update_fn(updates, state, params=None):
class ScaleByAdamState(NamedTuple):
"""State for the NAdam algorithm."""
+
count: chex.Array # shape=(), dtype=jnp.int32.
mu: optax.Updates
nu: optax.Updates
@@ -140,7 +142,8 @@ class ScaleByAdamState(NamedTuple):
def _update_moment(updates, moments, decay, order):
"""Compute the exponential moving average of the `order-th` moment."""
return jax.tree.map(
- lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
+ lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments
+ )
def _bias_correction(moment, decay, count):
@@ -156,11 +159,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True):
return optax.scale(m * learning_rate)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -170,101 +175,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+ init_value=hyperparameters.learning_rate, decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters)
opt_init_fn, opt_update_fn = nadamw(
- learning_rate=lr_schedule_fn,
- b1=1.0 - hyperparameters.one_minus_beta1,
- b2=hyperparameters.beta2,
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay)
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ learning_rate=lr_schedule_fn,
+ b1=1.0 - hyperparameters.one_minus_beta1,
+ b2=hyperparameters.beta2,
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ )
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -281,37 +300,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -351,14 +376,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py
index b8ac10f33..9752aef33 100644
--- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py
+++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py
@@ -1,26 +1,24 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax."""
import functools
-
-# isort: off
-# We have to turn off isort here to resolve a conflict between isort and yapf.
-from typing import (Any,
- Callable,
- Dict,
- Iterator,
- List,
- NamedTuple,
- Optional,
- Tuple,
- Union)
-# isort: on
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ NamedTuple,
+ Optional,
+ Tuple,
+ Union,
+)
import chex
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
@@ -30,15 +28,14 @@
# Forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py
def nadamw(
- learning_rate: Union[float, optax.Schedule],
- b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- weight_decay: float = 0.0,
- weight_decay_mask: Optional[Union[Any, Callable[[optax.Params],
- Any]]] = None,
+ learning_rate: Union[float, optax.Schedule],
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ weight_decay: float = 0.0,
+ weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None,
) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
@@ -73,19 +70,22 @@ def nadamw(
An (init_fn, update_fn) tuple.
"""
return optax.chain(
- scale_by_nadam(b1, b2, eps, eps_root, debias),
- optax.add_decayed_weights(weight_decay, weight_decay_mask),
- scale_by_learning_rate(learning_rate))
+ scale_by_nadam(b1, b2, eps, eps_root, debias),
+ optax.add_decayed_weights(weight_decay, weight_decay_mask),
+ scale_by_learning_rate(learning_rate),
+ )
# All functions below are forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py
-def scale_by_nadam(b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- power: float = 0.5) -> optax.GradientTransformation:
+def scale_by_nadam(
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ power: float = 0.5,
+) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
References:
@@ -124,7 +124,8 @@ def update_fn(updates, state, params=None):
mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
nu_hat = nu if not debias else _bias_correction(nu, b2, count)
updates = jax.tree.map(
- lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
+ lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat
+ )
return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
return optax.GradientTransformation(init_fn, update_fn)
@@ -132,6 +133,7 @@ def update_fn(updates, state, params=None):
class ScaleByAdamState(NamedTuple):
"""State for the NAdam algorithm."""
+
count: chex.Array # shape=(), dtype=jnp.int32.
mu: optax.Updates
nu: optax.Updates
@@ -140,7 +142,8 @@ class ScaleByAdamState(NamedTuple):
def _update_moment(updates, moments, decay, order):
"""Compute the exponential moving average of the `order-th` moment."""
return jax.tree.map(
- lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
+ lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments
+ )
def _bias_correction(moment, decay, count):
@@ -156,11 +159,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True):
return optax.scale(m * learning_rate)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -170,101 +175,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+ init_value=hyperparameters.learning_rate, decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters)
opt_init_fn, opt_update_fn = nadamw(
- learning_rate=lr_schedule_fn,
- b1=1.0 - hyperparameters.one_minus_beta1,
- b2=hyperparameters.beta2,
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay)
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ learning_rate=lr_schedule_fn,
+ b1=1.0 - hyperparameters.one_minus_beta1,
+ b2=hyperparameters.beta2,
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ )
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -281,37 +300,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -351,14 +376,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py
index a2f9fb4c5..f6c2faa9d 100644
--- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py
+++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py
@@ -3,13 +3,11 @@
import math
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import torch
-from torch import Tensor
import torch.distributed.nn as dist_nn
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from absl import logging
+from torch import Tensor
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
@@ -21,33 +19,30 @@
class NAdamW(torch.optim.Optimizer):
r"""Implements NAdamW algorithm.
- See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
- the NAdam algorithm (there is also a comment in the code which highlights
- the only difference of NAdamW and AdamW).
- For further details regarding the algorithm we refer to
- `Decoupled Weight Decay Regularization`_.
-
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float, optional): weight decay coefficient (default: 1e-2)
- .. _Decoupled Weight Decay Regularization:
- https://arxiv.org/abs/1711.05101
- .. _On the Convergence of Adam and Beyond:
- https://openreview.net/forum?id=ryQu7f-RZ
+ See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
+ the NAdam algorithm (there is also a comment in the code which highlights
+ the only difference of NAdamW and AdamW).
+ For further details regarding the algorithm we refer to
+ `Decoupled Weight Decay Regularization`_.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
"""
- def __init__(self,
- params,
- lr=1e-3,
- betas=(0.9, 0.999),
- eps=1e-8,
- weight_decay=1e-2):
+ def __init__(
+ self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
+ ):
if not 0.0 <= lr:
raise ValueError(f'Invalid learning rate: {lr}')
if not 0.0 <= eps:
@@ -59,7 +54,10 @@ def __init__(self,
if not 0.0 <= weight_decay:
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
defaults = {
- 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay
+ 'lr': lr,
+ 'betas': betas,
+ 'eps': eps,
+ 'weight_decay': weight_decay,
}
super().__init__(params, defaults)
@@ -67,7 +65,8 @@ def __setstate__(self, state):
super().__setstate__(state)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
- state_values[0]['step'])
+ state_values[0]['step']
+ )
if not step_is_tensor:
for s in state_values:
s['step'] = torch.tensor(float(s['step']))
@@ -76,9 +75,9 @@ def __setstate__(self, state):
def step(self, closure=None):
"""Performs a single optimization step.
- Args:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
"""
self._cuda_graph_capture_health_check()
@@ -107,51 +106,57 @@ def step(self, closure=None):
# State initialization
if len(state) == 0:
- state['step'] = torch.tensor(0.)
+ state['step'] = torch.tensor(0.0)
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
state_steps.append(state['step'])
nadamw(
- params_with_grad,
- grads,
- exp_avgs,
- exp_avg_sqs,
- state_steps,
- beta1=beta1,
- beta2=beta2,
- lr=group['lr'],
- weight_decay=group['weight_decay'],
- eps=group['eps'])
+ params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ state_steps,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ )
return loss
-def nadamw(params: List[Tensor],
- grads: List[Tensor],
- exp_avgs: List[Tensor],
- exp_avg_sqs: List[Tensor],
- state_steps: List[Tensor],
- beta1: float,
- beta2: float,
- lr: float,
- weight_decay: float,
- eps: float) -> None:
+def nadamw(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ state_steps: List[Tensor],
+ beta1: float,
+ beta2: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+) -> None:
r"""Functional API that performs NAdamW algorithm computation.
- See NAdamW class for details.
+ See NAdamW class for details.
"""
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
- 'API has changed, `state_steps` argument must contain a list of' +
- ' singleton tensors')
+ 'API has changed, `state_steps` argument must contain a list of'
+ + ' singleton tensors'
+ )
for i, param in enumerate(params):
grad = grads[i]
@@ -189,54 +194,59 @@ def nadamw(params: List[Tensor],
exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_state
del rng
optimizer_state = {
- 'optimizer':
- NAdamW(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(1.0 - hyperparameters.one_minus_beta1,
- hyperparameters.beta2),
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay),
+ 'optimizer': NAdamW(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2),
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ ),
}
def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup = LinearLR(
- optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
+ optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
- optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
+ optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
+ )
optimizer_state['scheduler'] = pytorch_cosine_warmup(
- workload.step_hint, hyperparameters, optimizer_state['optimizer'])
+ workload.step_hint, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -248,26 +258,30 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -280,7 +294,8 @@ def update_params(
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -289,31 +304,38 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -353,14 +375,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py
index a37b0d341..68ff30b2a 100644
--- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py
+++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py
@@ -3,13 +3,11 @@
import math
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import torch
-from torch import Tensor
import torch.distributed.nn as dist_nn
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from absl import logging
+from torch import Tensor
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
@@ -21,33 +19,30 @@
class NAdamW(torch.optim.Optimizer):
r"""Implements NAdamW algorithm.
- See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
- the NAdam algorithm (there is also a comment in the code which highlights
- the only difference of NAdamW and AdamW).
- For further details regarding the algorithm we refer to
- `Decoupled Weight Decay Regularization`_.
-
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float, optional): weight decay coefficient (default: 1e-2)
- .. _Decoupled Weight Decay Regularization:
- https://arxiv.org/abs/1711.05101
- .. _On the Convergence of Adam and Beyond:
- https://openreview.net/forum?id=ryQu7f-RZ
+ See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
+ the NAdam algorithm (there is also a comment in the code which highlights
+ the only difference of NAdamW and AdamW).
+ For further details regarding the algorithm we refer to
+ `Decoupled Weight Decay Regularization`_.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
"""
- def __init__(self,
- params,
- lr=1e-3,
- betas=(0.9, 0.999),
- eps=1e-8,
- weight_decay=1e-2):
+ def __init__(
+ self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
+ ):
if not 0.0 <= lr:
raise ValueError(f'Invalid learning rate: {lr}')
if not 0.0 <= eps:
@@ -59,7 +54,10 @@ def __init__(self,
if not 0.0 <= weight_decay:
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
defaults = {
- 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay
+ 'lr': lr,
+ 'betas': betas,
+ 'eps': eps,
+ 'weight_decay': weight_decay,
}
super().__init__(params, defaults)
@@ -67,7 +65,8 @@ def __setstate__(self, state):
super().__setstate__(state)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
- state_values[0]['step'])
+ state_values[0]['step']
+ )
if not step_is_tensor:
for s in state_values:
s['step'] = torch.tensor(float(s['step']))
@@ -76,9 +75,9 @@ def __setstate__(self, state):
def step(self, closure=None):
"""Performs a single optimization step.
- Args:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
"""
self._cuda_graph_capture_health_check()
@@ -107,51 +106,57 @@ def step(self, closure=None):
# State initialization
if len(state) == 0:
- state['step'] = torch.tensor(0.)
+ state['step'] = torch.tensor(0.0)
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
state_steps.append(state['step'])
nadamw(
- params_with_grad,
- grads,
- exp_avgs,
- exp_avg_sqs,
- state_steps,
- beta1=beta1,
- beta2=beta2,
- lr=group['lr'],
- weight_decay=group['weight_decay'],
- eps=group['eps'])
+ params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ state_steps,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ )
return loss
-def nadamw(params: List[Tensor],
- grads: List[Tensor],
- exp_avgs: List[Tensor],
- exp_avg_sqs: List[Tensor],
- state_steps: List[Tensor],
- beta1: float,
- beta2: float,
- lr: float,
- weight_decay: float,
- eps: float) -> None:
+def nadamw(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ state_steps: List[Tensor],
+ beta1: float,
+ beta2: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+) -> None:
r"""Functional API that performs NAdamW algorithm computation.
- See NAdamW class for details.
+ See NAdamW class for details.
"""
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
- 'API has changed, `state_steps` argument must contain a list of' +
- ' singleton tensors')
+ 'API has changed, `state_steps` argument must contain a list of'
+ + ' singleton tensors'
+ )
for i, param in enumerate(params):
grad = grads[i]
@@ -189,54 +194,59 @@ def nadamw(params: List[Tensor],
exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_state
del rng
optimizer_state = {
- 'optimizer':
- NAdamW(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(1.0 - hyperparameters.one_minus_beta1,
- hyperparameters.beta2),
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay),
+ 'optimizer': NAdamW(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2),
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ ),
}
def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup = LinearLR(
- optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
+ optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
- optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
+ optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
+ )
optimizer_state['scheduler'] = pytorch_cosine_warmup(
- workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer'])
+ workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -248,26 +258,30 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -280,7 +294,8 @@ def update_params(
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -289,31 +304,38 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -353,14 +375,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/prize_qualification_baselines/external_tuning/tuning_search_space.json b/prize_qualification_baselines/external_tuning/tuning_search_space.json
index 199f77041..ad68372c6 100644
--- a/prize_qualification_baselines/external_tuning/tuning_search_space.json
+++ b/prize_qualification_baselines/external_tuning/tuning_search_space.json
@@ -1,53 +1,47 @@
[
- {
- "dropout_rate": 0.0,
- "label_smoothing": 0.1,
- "learning_rate": 0.001308209823469072,
- "one_minus_beta1": 0.02686663061,
- "beta2": 0.9981232922116359,
- "weight_decay": 0.16375311233774334,
- "warmup_factor": 0.1
- },
- {
- "dropout_rate": 0.0,
- "label_smoothing": 0.2,
- "learning_rate": 0.0008445074561975979,
- "one_minus_beta1": 0.11042418465,
- "beta2": 0.9978504782314613,
- "weight_decay": 0.08135402759553023,
- "warmup_factor": 0.05
- },
- {
- "dropout_rate": 0.0,
- "label_smoothing": 0.0,
- "learning_rate": 0.001308209823469072,
- "one_minus_beta1": 0.02686663061,
- "beta2": 0.9981232922116359,
- "weight_decay": 0.16375311233774334,
- "warmup_factor": 0.1
- },
- {
- "dropout_rate": 0.0,
- "label_smoothing": 0.0,
- "learning_rate": 0.004958460849689891,
- "one_minus_beta1": 0.13625575743,
- "beta2": 0.6291854735396584,
- "weight_decay": 0.1147386261512052,
- "warmup_factor": 0.02
- },
- {
- "dropout_rate": 0.1,
- "label_smoothing": 0.0,
- "learning_rate": 0.0017486387539278373,
- "one_minus_beta1": 0.06733926164,
- "beta2": 0.9955159689799007,
- "weight_decay": 0.08121616522670176,
- "warmup_factor": 0.02
- }
+ {
+ "dropout_rate": 0.0,
+ "label_smoothing": 0.1,
+ "learning_rate": 0.001308209823469072,
+ "one_minus_beta1": 0.02686663061,
+ "beta2": 0.9981232922116359,
+ "weight_decay": 0.16375311233774334,
+ "warmup_factor": 0.1
+ },
+ {
+ "dropout_rate": 0.0,
+ "label_smoothing": 0.2,
+ "learning_rate": 0.0008445074561975979,
+ "one_minus_beta1": 0.11042418465,
+ "beta2": 0.9978504782314613,
+ "weight_decay": 0.08135402759553023,
+ "warmup_factor": 0.05
+ },
+ {
+ "dropout_rate": 0.0,
+ "label_smoothing": 0.0,
+ "learning_rate": 0.001308209823469072,
+ "one_minus_beta1": 0.02686663061,
+ "beta2": 0.9981232922116359,
+ "weight_decay": 0.16375311233774334,
+ "warmup_factor": 0.1
+ },
+ {
+ "dropout_rate": 0.0,
+ "label_smoothing": 0.0,
+ "learning_rate": 0.004958460849689891,
+ "one_minus_beta1": 0.13625575743,
+ "beta2": 0.6291854735396584,
+ "weight_decay": 0.1147386261512052,
+ "warmup_factor": 0.02
+ },
+ {
+ "dropout_rate": 0.1,
+ "label_smoothing": 0.0,
+ "learning_rate": 0.0017486387539278373,
+ "one_minus_beta1": 0.06733926164,
+ "beta2": 0.9955159689799007,
+ "weight_decay": 0.08121616522670176,
+ "warmup_factor": 0.02
+ }
]
-
-
-
-
-
-
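
As a hedged usage sketch (not part of the diff), one way to consume a point from `tuning_search_space.json` and expose it with attribute access, mirroring the `HPARAMS` namedtuple pattern used by the self-tuning baselines below; the file path is an assumption:

```python
# Load the external-tuning search space and wrap one trial's hyperparameters in a
# namedtuple, so fields read like hparams.learning_rate (illustrative only).
import collections
import json

with open('tuning_search_space.json') as f:  # assumed relative path
  search_space = json.load(f)  # list of dicts, one per tuning trial

point = search_space[0]  # e.g. the first of the five points above
Hyperparameters = collections.namedtuple('Hyperparameters', point.keys())
hparams = Hyperparameters(**point)

print(hparams.learning_rate, hparams.one_minus_beta1, hparams.warmup_factor)
```
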
diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py
index 78c3b5b3e..f61a7bdcd 100644
--- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py
+++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py
@@ -1,53 +1,50 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax."""
import functools
-
-# isort: off
-# We have to turn off isort here to resolve a conflict between isort and yapf.
-from typing import (Any,
- Callable,
- Dict,
- Iterator,
- List,
- NamedTuple,
- Optional,
- Tuple,
- Union)
-# isort: on
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ NamedTuple,
+ Optional,
+ Tuple,
+ Union,
+)
import chex
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
_GRAD_CLIP_EPS = 1e-6
HPARAMS = {
- "dropout_rate": 0.1,
- "learning_rate": 0.0017486387539278373,
- "one_minus_beta1": 0.06733926164,
- "beta2": 0.9955159689799007,
- "weight_decay": 0.08121616522670176,
- "warmup_factor": 0.02
+ 'dropout_rate': 0.1,
+ 'learning_rate': 0.0017486387539278373,
+ 'one_minus_beta1': 0.06733926164,
+ 'beta2': 0.9955159689799007,
+ 'weight_decay': 0.08121616522670176,
+ 'warmup_factor': 0.02,
}
# Forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py
def nadamw(
- learning_rate: Union[float, optax.Schedule],
- b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- weight_decay: float = 0.0,
- weight_decay_mask: Optional[Union[Any, Callable[[optax.Params],
- Any]]] = None,
+ learning_rate: Union[float, optax.Schedule],
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ weight_decay: float = 0.0,
+ weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None,
) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
@@ -82,19 +79,22 @@ def nadamw(
An (init_fn, update_fn) tuple.
"""
return optax.chain(
- scale_by_nadam(b1, b2, eps, eps_root, debias),
- optax.add_decayed_weights(weight_decay, weight_decay_mask),
- scale_by_learning_rate(learning_rate))
+ scale_by_nadam(b1, b2, eps, eps_root, debias),
+ optax.add_decayed_weights(weight_decay, weight_decay_mask),
+ scale_by_learning_rate(learning_rate),
+ )
# All functions below are forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py
-def scale_by_nadam(b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- power: float = 0.5) -> optax.GradientTransformation:
+def scale_by_nadam(
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ power: float = 0.5,
+) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
References:
@@ -133,7 +133,8 @@ def update_fn(updates, state, params=None):
mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
nu_hat = nu if not debias else _bias_correction(nu, b2, count)
updates = jax.tree.map(
- lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
+ lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat
+ )
return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
return optax.GradientTransformation(init_fn, update_fn)
@@ -141,6 +142,7 @@ def update_fn(updates, state, params=None):
class ScaleByAdamState(NamedTuple):
"""State for the NAdam algorithm."""
+
count: chex.Array # shape=(), dtype=jnp.int32.
mu: optax.Updates
nu: optax.Updates
@@ -149,7 +151,8 @@ class ScaleByAdamState(NamedTuple):
def _update_moment(updates, moments, decay, order):
"""Compute the exponential moving average of the `order-th` moment."""
return jax.tree.map(
- lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
+ lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments
+ )
def _bias_correction(moment, decay, count):
@@ -165,11 +168,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True):
return optax.scale(m * learning_rate)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -182,101 +187,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters['warmup_factor'] * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters['learning_rate'],
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters['learning_rate'],
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters['learning_rate'], decay_steps=cosine_steps)
+ init_value=hyperparameters['learning_rate'], decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters)
opt_init_fn, opt_update_fn = nadamw(
- learning_rate=lr_schedule_fn,
- b1=1.0 - hyperparameters['one_minus_beta1'],
- b2=hyperparameters['beta2'],
- eps=1e-8,
- weight_decay=hyperparameters['weight_decay'])
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ learning_rate=lr_schedule_fn,
+ b1=1.0 - hyperparameters['one_minus_beta1'],
+ b2=hyperparameters['beta2'],
+ eps=1e-8,
+ weight_decay=hyperparameters['weight_decay'],
+ )
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -296,37 +315,43 @@ def update_params(
grad_clip = hyperparameters['grad_clip']
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -366,14 +391,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py
index ffe854a0e..130ebdabe 100644
--- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py
+++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py
@@ -1,53 +1,50 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax."""
import functools
-
-# isort: off
-# We have to turn off isort here to resolve a conflict between isort and yapf.
-from typing import (Any,
- Callable,
- Dict,
- Iterator,
- List,
- NamedTuple,
- Optional,
- Tuple,
- Union)
-# isort: on
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ NamedTuple,
+ Optional,
+ Tuple,
+ Union,
+)
import chex
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
_GRAD_CLIP_EPS = 1e-6
HPARAMS = {
- "dropout_rate": 0.1,
- "learning_rate": 0.0017486387539278373,
- "one_minus_beta1": 0.06733926164,
- "beta2": 0.9955159689799007,
- "weight_decay": 0.08121616522670176,
- "warmup_factor": 0.02
+ 'dropout_rate': 0.1,
+ 'learning_rate': 0.0017486387539278373,
+ 'one_minus_beta1': 0.06733926164,
+ 'beta2': 0.9955159689799007,
+ 'weight_decay': 0.08121616522670176,
+ 'warmup_factor': 0.02,
}
# Forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py
def nadamw(
- learning_rate: Union[float, optax.Schedule],
- b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- weight_decay: float = 0.0,
- weight_decay_mask: Optional[Union[Any, Callable[[optax.Params],
- Any]]] = None,
+ learning_rate: Union[float, optax.Schedule],
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ weight_decay: float = 0.0,
+ weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None,
) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
@@ -82,19 +79,22 @@ def nadamw(
An (init_fn, update_fn) tuple.
"""
return optax.chain(
- scale_by_nadam(b1, b2, eps, eps_root, debias),
- optax.add_decayed_weights(weight_decay, weight_decay_mask),
- scale_by_learning_rate(learning_rate))
+ scale_by_nadam(b1, b2, eps, eps_root, debias),
+ optax.add_decayed_weights(weight_decay, weight_decay_mask),
+ scale_by_learning_rate(learning_rate),
+ )
# All functions below are forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py
-def scale_by_nadam(b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- power: float = 0.5) -> optax.GradientTransformation:
+def scale_by_nadam(
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ power: float = 0.5,
+) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
References:
@@ -133,7 +133,8 @@ def update_fn(updates, state, params=None):
mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
nu_hat = nu if not debias else _bias_correction(nu, b2, count)
updates = jax.tree.map(
- lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
+ lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat
+ )
return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
return optax.GradientTransformation(init_fn, update_fn)
@@ -141,6 +142,7 @@ def update_fn(updates, state, params=None):
class ScaleByAdamState(NamedTuple):
"""State for the NAdam algorithm."""
+
count: chex.Array # shape=(), dtype=jnp.int32.
mu: optax.Updates
nu: optax.Updates
@@ -149,7 +151,8 @@ class ScaleByAdamState(NamedTuple):
def _update_moment(updates, moments, decay, order):
"""Compute the exponential moving average of the `order-th` moment."""
return jax.tree.map(
- lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
+ lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments
+ )
def _bias_correction(moment, decay, count):
@@ -165,11 +168,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True):
return optax.scale(m * learning_rate)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -182,101 +187,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters['warmup_factor'] * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters['learning_rate'],
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters['learning_rate'],
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters['learning_rate'], decay_steps=cosine_steps)
+ init_value=hyperparameters['learning_rate'], decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters)
opt_init_fn, opt_update_fn = nadamw(
- learning_rate=lr_schedule_fn,
- b1=1.0 - hyperparameters['one_minus_beta1'],
- b2=hyperparameters['beta2'],
- eps=1e-8,
- weight_decay=hyperparameters['weight_decay'])
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ learning_rate=lr_schedule_fn,
+ b1=1.0 - hyperparameters['one_minus_beta1'],
+ b2=hyperparameters['beta2'],
+ eps=1e-8,
+ weight_decay=hyperparameters['weight_decay'],
+ )
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -296,37 +315,43 @@ def update_params(
grad_clip = hyperparameters['grad_clip']
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -366,14 +391,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py
index 554a28762..8e8adbeaa 100644
--- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py
+++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py
@@ -4,13 +4,11 @@
import math
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import torch
-from torch import Tensor
import torch.distributed.nn as dist_nn
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from absl import logging
+from torch import Tensor
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
@@ -18,12 +16,12 @@
USE_PYTORCH_DDP = pytorch_setup()[0]
HPARAMS = {
- "dropout_rate": 0.1,
- "learning_rate": 0.0017486387539278373,
- "one_minus_beta1": 0.06733926164,
- "beta2": 0.9955159689799007,
- "weight_decay": 0.08121616522670176,
- "warmup_factor": 0.02
+ 'dropout_rate': 0.1,
+ 'learning_rate': 0.0017486387539278373,
+ 'one_minus_beta1': 0.06733926164,
+ 'beta2': 0.9955159689799007,
+ 'weight_decay': 0.08121616522670176,
+ 'warmup_factor': 0.02,
}
HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS)
@@ -32,33 +30,30 @@
class NAdamW(torch.optim.Optimizer):
r"""Implements NAdamW algorithm.
- See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
- the NAdam algorithm (there is also a comment in the code which highlights
- the only difference of NAdamW and AdamW).
- For further details regarding the algorithm we refer to
- `Decoupled Weight Decay Regularization`_.
-
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float, optional): weight decay coefficient (default: 1e-2)
- .. _Decoupled Weight Decay Regularization:
- https://arxiv.org/abs/1711.05101
- .. _On the Convergence of Adam and Beyond:
- https://openreview.net/forum?id=ryQu7f-RZ
+ See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
+ the NAdam algorithm (there is also a comment in the code which highlights
+ the only difference of NAdamW and AdamW).
+ For further details regarding the algorithm we refer to
+ `Decoupled Weight Decay Regularization`_.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
"""
- def __init__(self,
- params,
- lr=1e-3,
- betas=(0.9, 0.999),
- eps=1e-8,
- weight_decay=1e-2):
+ def __init__(
+ self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
+ ):
if not 0.0 <= lr:
raise ValueError(f'Invalid learning rate: {lr}')
if not 0.0 <= eps:
@@ -70,7 +65,10 @@ def __init__(self,
if not 0.0 <= weight_decay:
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
defaults = {
- 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay
+ 'lr': lr,
+ 'betas': betas,
+ 'eps': eps,
+ 'weight_decay': weight_decay,
}
super().__init__(params, defaults)
@@ -78,7 +76,8 @@ def __setstate__(self, state):
super().__setstate__(state)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
- state_values[0]['step'])
+ state_values[0]['step']
+ )
if not step_is_tensor:
for s in state_values:
s['step'] = torch.tensor(float(s['step']))
@@ -87,9 +86,9 @@ def __setstate__(self, state):
def step(self, closure=None):
"""Performs a single optimization step.
- Args:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
"""
self._cuda_graph_capture_health_check()
@@ -118,51 +117,57 @@ def step(self, closure=None):
# State initialization
if len(state) == 0:
- state['step'] = torch.tensor(0.)
+ state['step'] = torch.tensor(0.0)
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
state_steps.append(state['step'])
nadamw(
- params_with_grad,
- grads,
- exp_avgs,
- exp_avg_sqs,
- state_steps,
- beta1=beta1,
- beta2=beta2,
- lr=group['lr'],
- weight_decay=group['weight_decay'],
- eps=group['eps'])
+ params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ state_steps,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ )
return loss
-def nadamw(params: List[Tensor],
- grads: List[Tensor],
- exp_avgs: List[Tensor],
- exp_avg_sqs: List[Tensor],
- state_steps: List[Tensor],
- beta1: float,
- beta2: float,
- lr: float,
- weight_decay: float,
- eps: float) -> None:
+def nadamw(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ state_steps: List[Tensor],
+ beta1: float,
+ beta2: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+) -> None:
r"""Functional API that performs NAdamW algorithm computation.
- See NAdamW class for details.
+ See NAdamW class for details.
"""
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
- 'API has changed, `state_steps` argument must contain a list of' +
- ' singleton tensors')
+ 'API has changed, `state_steps` argument must contain a list of'
+ + ' singleton tensors'
+ )
for i, param in enumerate(params):
grad = grads[i]
@@ -200,11 +205,13 @@ def nadamw(params: List[Tensor],
exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_state
del rng
@@ -213,44 +220,47 @@ def init_optimizer_state(workload: spec.Workload,
hyperparameters = HPARAMS
optimizer_state = {
- 'optimizer':
- NAdamW(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(1.0 - hyperparameters.one_minus_beta1,
- hyperparameters.beta2),
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay),
+ 'optimizer': NAdamW(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2),
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ ),
}
def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup = LinearLR(
- optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
+ optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
- optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
+ optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
+ )
optimizer_state['scheduler'] = pytorch_cosine_warmup(
- workload.step_hint, hyperparameters, optimizer_state['optimizer'])
+ workload.step_hint, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -265,26 +275,30 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -297,7 +311,8 @@ def update_params(
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -306,31 +321,38 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -370,14 +392,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py
index e4317fa18..5d7c444a4 100644
--- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py
+++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py
@@ -4,13 +4,11 @@
import math
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import torch
-from torch import Tensor
import torch.distributed.nn as dist_nn
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from absl import logging
+from torch import Tensor
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
@@ -18,12 +16,12 @@
USE_PYTORCH_DDP = pytorch_setup()[0]
HPARAMS = {
- "dropout_rate": 0.1,
- "learning_rate": 0.0017486387539278373,
- "one_minus_beta1": 0.06733926164,
- "beta2": 0.9955159689799007,
- "weight_decay": 0.08121616522670176,
- "warmup_factor": 0.02
+ 'dropout_rate': 0.1,
+ 'learning_rate': 0.0017486387539278373,
+ 'one_minus_beta1': 0.06733926164,
+ 'beta2': 0.9955159689799007,
+ 'weight_decay': 0.08121616522670176,
+ 'warmup_factor': 0.02,
}
HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS)
@@ -32,33 +30,30 @@
class NAdamW(torch.optim.Optimizer):
r"""Implements NAdamW algorithm.
- See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
- the NAdam algorithm (there is also a comment in the code which highlights
- the only difference of NAdamW and AdamW).
- For further details regarding the algorithm we refer to
- `Decoupled Weight Decay Regularization`_.
-
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float, optional): weight decay coefficient (default: 1e-2)
- .. _Decoupled Weight Decay Regularization:
- https://arxiv.org/abs/1711.05101
- .. _On the Convergence of Adam and Beyond:
- https://openreview.net/forum?id=ryQu7f-RZ
+ See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
+ the NAdam algorithm (there is also a comment in the code which highlights
+ the only difference of NAdamW and AdamW).
+ For further details regarding the algorithm we refer to
+ `Decoupled Weight Decay Regularization`_.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
"""
- def __init__(self,
- params,
- lr=1e-3,
- betas=(0.9, 0.999),
- eps=1e-8,
- weight_decay=1e-2):
+ def __init__(
+ self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
+ ):
if not 0.0 <= lr:
raise ValueError(f'Invalid learning rate: {lr}')
if not 0.0 <= eps:
@@ -70,7 +65,10 @@ def __init__(self,
if not 0.0 <= weight_decay:
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
defaults = {
- 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay
+ 'lr': lr,
+ 'betas': betas,
+ 'eps': eps,
+ 'weight_decay': weight_decay,
}
super().__init__(params, defaults)
@@ -78,7 +76,8 @@ def __setstate__(self, state):
super().__setstate__(state)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
- state_values[0]['step'])
+ state_values[0]['step']
+ )
if not step_is_tensor:
for s in state_values:
s['step'] = torch.tensor(float(s['step']))
@@ -87,9 +86,9 @@ def __setstate__(self, state):
def step(self, closure=None):
"""Performs a single optimization step.
- Args:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
"""
self._cuda_graph_capture_health_check()
@@ -118,51 +117,57 @@ def step(self, closure=None):
# State initialization
if len(state) == 0:
- state['step'] = torch.tensor(0.)
+ state['step'] = torch.tensor(0.0)
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
state_steps.append(state['step'])
nadamw(
- params_with_grad,
- grads,
- exp_avgs,
- exp_avg_sqs,
- state_steps,
- beta1=beta1,
- beta2=beta2,
- lr=group['lr'],
- weight_decay=group['weight_decay'],
- eps=group['eps'])
+ params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ state_steps,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ )
return loss
-def nadamw(params: List[Tensor],
- grads: List[Tensor],
- exp_avgs: List[Tensor],
- exp_avg_sqs: List[Tensor],
- state_steps: List[Tensor],
- beta1: float,
- beta2: float,
- lr: float,
- weight_decay: float,
- eps: float) -> None:
+def nadamw(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ state_steps: List[Tensor],
+ beta1: float,
+ beta2: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+) -> None:
r"""Functional API that performs NAdamW algorithm computation.
- See NAdamW class for details.
+ See NAdamW class for details.
"""
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
- 'API has changed, `state_steps` argument must contain a list of' +
- ' singleton tensors')
+ 'API has changed, `state_steps` argument must contain a list of'
+ + ' singleton tensors'
+ )
for i, param in enumerate(params):
grad = grads[i]
@@ -200,11 +205,13 @@ def nadamw(params: List[Tensor],
exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_state
del rng
@@ -213,44 +220,47 @@ def init_optimizer_state(workload: spec.Workload,
hyperparameters = HPARAMS
optimizer_state = {
- 'optimizer':
- NAdamW(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(1.0 - hyperparameters.one_minus_beta1,
- hyperparameters.beta2),
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay),
+ 'optimizer': NAdamW(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2),
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ ),
}
def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup = LinearLR(
- optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
+ optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
- optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
+ optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
+ )
optimizer_state['scheduler'] = pytorch_cosine_warmup(
- workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer'])
+ workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -265,26 +275,30 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -297,7 +311,8 @@ def update_params(
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -306,31 +321,38 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -370,14 +392,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/pyproject.toml b/pyproject.toml
index 4e15e4400..b6cfa42cd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,9 +27,9 @@ classifiers = [
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
- "Programming Language :: Python :: 3.8",
- "Programming Language :: Python :: 3.9",
- "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
@@ -46,7 +46,6 @@ dependencies = [
"clu==0.0.12",
"matplotlib>=3.9.2",
"tabulate==0.9.0",
-
]
[build-system]
@@ -69,19 +68,11 @@ version_file = "algoperf/_version.py"
###############################################################################
[project.optional-dependencies]
# All workloads
-full = [
- "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]",
-]
+full = ["algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]"]
# All workloads plus development dependencies
full_dev = ["algoperf[full,dev]"]
# Dependencies for developing the package
-dev = [
- "isort==5.12.0",
- "pylint==2.17.4",
- "pytest==8.3.3",
- "yapf==0.32.0",
- "pre-commit==4.0.1",
-]
+dev = ["ruff==0.12.0", "pytest==8.3.3", "pre-commit==4.0.1"]
wandb = ["wandb==0.19.6"]
@@ -104,11 +95,7 @@ jax_core_deps = [
"ml_dtypes==0.4.1",
"protobuf==4.25.5",
]
-jax_cpu = [
- "jax==0.4.28",
- "jaxlib==0.4.28",
- "algoperf[jax_core_deps]",
-]
+jax_cpu = ["jax==0.4.28", "jaxlib==0.4.28", "algoperf[jax_core_deps]"]
jax_gpu = [
"jax==0.4.28",
"jaxlib==0.4.28",
@@ -117,217 +104,75 @@ jax_gpu = [
"algoperf[jax_core_deps]",
]
pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"]
-pytorch_gpu = [
- "torch==2.5.1",
- "torchvision==0.20.1",
-] # Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA.
+pytorch_gpu = ["torch==2.5.1", "torchvision==0.20.1"]
+# Note: omit the CUDA suffix; installing from the appropriate wheel
+# will then use the locally installed CUDA.
###############################################################################
-# Linting Configurations #
+# Linting & Formatting Configurations #
###############################################################################
-
-# yapf configuration
-[tool.yapf]
-based_on_style = "yapf"
-each_dict_entry_on_separate_line = false
-split_all_top_level_comma_separated_values = true
-[tool.yapfignore]
-ignore_patterns = ["algoperf/_version.py"]
-
-# isort configuration
-[tool.isort]
-profile = "google"
-
-# pylint configuration
-[tool.pylint.MASTER]
-persistent = false
-ignore = "get_references_web.py,get_references_web_single_group.py,_version.py"
-
-[tool.pylint.REPORTS]
-reports = false
-msg-template = "{msg_id}:{line:3} {obj}: {msg} [{symbol}]"
-
-[tool.pylint.MESSAGES_CONTROL]
-enable = "indexing-exception,old-raise-syntax"
-
-[tool.pylint.BASIC]
-# Required attributes for module, separated by a comma
-#required-attributes=
-# Regular expression which should only match the name
-# of functions or classes which do not require a docstring.
-no-docstring-rgx = "(__.*__|main)"
-# Min length in lines of a function that requires a docstring.
-docstring-min-length = 10
-# Regular expression which should only match correct module names. The
-# leading underscore is sanctioned for private modules by Google's style
-# guide.
-#
-# There are exceptions to the basic rule (_?[a-z][a-z0-9_]*) to cover
-# requirements of Python's module system.
-module-rgx = "^(_?[a-z][a-z0-9_]*)|__init__$"
-# Regular expression which should only match correct module level names
-const-rgx = "^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$"
-# Regular expression which should only match correct class attribute
-class-attribute-rgx = "^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$"
-# Regular expression which should only match correct class names
-class-rgx = "^_?[A-Z][a-zA-Z0-9]*$"
-# Regular expression which should only match correct function names.
-# 'camel_case' and 'snake_case' group names are used for consistency of naming
-# styles across functions and methods.
-function-rgx = "^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$"
-# Regular expression which should only match correct method names.
-# 'camel_case' and 'snake_case' group names are used for consistency of naming
-# styles across functions and methods. 'exempt' indicates a name which is
-# consistent with all naming styles.
-method-rgx = "(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|_testDatasetSize|setUpClass|test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|(?:test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$"
-# Regular expression which should only match correct instance attribute names
-attr-rgx = "^_{0,2}[a-z][a-z0-9_]*$"
-# Regular expression which should only match correct argument names
-argument-rgx = "^[a-z][a-z0-9_]*$"
-# Regular expression which should only match correct variable names
-variable-rgx = "^[a-z][a-z0-9_]*$"
-# Regular expression which should only match correct list comprehension /
-# generator expression variable names
-inlinevar-rgx = "^[a-z][a-z0-9_]*$"
-# Good variable names which should always be accepted, separated by a comma
-good-names = "main,_"
-# Bad variable names which should always be refused, separated by a comma
-bad-names = ""
-# List of builtins function names that should not be used, separated by a comma
-#bad-functions=input,apply,reduce
-# List of decorators that define properties, such as abc.abstractproperty.
-property-classes = "abc.abstractproperty"
-
-[tool.pylint.typecheck]
-# Tells whether missing members accessed in mixin class should be ignored. A
-# mixin class is detected if its name ends with "mixin" (case insensitive).
-ignore-mixin-members = true
-
-# List of decorators that create context managers from functions, such as
-# contextlib.contextmanager.
-contextmanager-decorators = [
- "contextlib.contextmanager",
- "contextlib2.contextmanager",
-]
-
-[tool.pylint.VARIABLES]
-# Tells whether we should check for unused import in __init__ files.
-init-import = false
-
-# A regular expression matching names used for dummy variables (i.e. not used).
-dummy-variables-rgx = "^\\*{0,2}(_$|unused_|dummy_)"
-
-# List of additional names supposed to be defined in builtins.
-additional-builtins = []
-
-[tool.pylint.CLASSES]
-# List of method names used to declare (i.e. assign) instance attributes.
-defining-attr-methods = ["__init__", "__new__", "setUp"]
-
-# Valid names for the first argument to a class method.
-valid-classmethod-first-arg = ["cls", "class_"]
-
-[tool.pylint.EXCEPTIONS]
-overgeneral-exceptions = [
- "builtins.StandardError",
- "builtins.Exception",
- "builtins.BaseException",
-]
-
-[tool.pylint.IMPORTS]
-# Deprecated modules which should not be used, separated by a comma
-deprecated-modules = ["regsub", "TERMIOS", "Bastion", "rexec", "sets"]
-
-[tool.pylint.FORMAT]
-# List of checkers and warnings to disable.
-disable = [
- "abstract-method",
- "access-member-before-definition",
- "arguments-differ",
- "assignment-from-no-return",
- "attribute-defined-outside-init",
- "bad-mcs-classmethod-argument",
- "bad-option-value",
- "c-extension-no-member",
- "consider-merging-isinstance",
- "consider-using-dict-comprehension",
- "consider-using-enumerate",
- "consider-using-in",
- "consider-using-set-comprehension",
- "consider-using-ternary",
- "deprecated-method",
- "design",
- "file-ignored",
- "fixme",
- "global-statement",
- "import-error",
- "inconsistent-return-statements",
- "invalid-unary-operand-type",
- "len-as-condition",
- "locally-disabled",
- "locally-enabled",
- "misplaced-comparison-constant",
- "missing-docstring",
- "multiple-imports",
- "no-else-return",
- "no-member",
- "no-name-in-module",
- "no-self-use",
- "no-value-for-parameter",
- "not-an-iterable",
- "not-context-manager",
- "pointless-except",
- "protected-access",
- "redefined-argument-from-local",
- "signature-differs",
- "similarities",
- "simplifiable-if-expression",
- "star-args",
- "super-init-not-called",
- "suppressed-message",
- "too-many-function-args",
- "trailing-comma-tuple",
- "trailing-newlines",
- "ungrouped-imports",
- "unnecessary-pass",
- "unsubscriptable-object",
- "unused-argument",
- "useless-object-inheritance",
- "useless-return",
- "useless-suppression",
- "wrong-import-order",
- "wrong-import-position",
- "unneeded-not",
- "unexpected-keyword-arg",
- "redundant-keyword-arg",
- "unspecified-encoding",
- "logging-fstring-interpolation",
- "consider-using-f-string",
- "use-dict-literal",
+[tool.ruff]
+line-length = 80
+indent-width = 2
+exclude = ["_version.py"]
+target-version = "py311"
+
+[tool.ruff.format]
+quote-style = "single"
+
+[tool.ruff.lint]
+# Commented-out rules below could be enabled in the future:
+extend-select = [
+ "BLE", # disallow catch-all exceptions
+ "COM", # enforce trailing comma rules
+ "F", # Pyflakes rules
+ "FA", # Enforce from __future__ import annotations
+ "I", # Isort rules
+ "ICN", # Use common import conventions
+ "PLE", # Pylint Errors
+ "TID", # Some good import practices
+ # "A", # flake8-builtins: detect shadowed builtins
+ # "B", # flake8-bugbear:
+ # "C4", # flake8-comprehensions: catch incorrect use of comprehensions
+ # "D", # pydocstyle
+ # "DOC", # pydoclint
+ # "DTZ", # flake8-datetimez: strict timezone manipulation with datetime
+ # "E", # pycodestyle errors
+ # "FBT", # flake8-boolean-trap: detect boolean traps
+ # "ISC", # flake8-implicit-str-concat: good use of string concatenation
+ # "N", # pep8-naming: enforce naming conventions
+ # "NPY", # Some numpy-specific things
+ # "PL", # All Pylint rules
+ # "PLC", # Pylint Convention
+ # "PLR", # Pylint Refactor
+ # "PLW", # Pylint Warnings
+ # "PTH", # flake8-use-pathlib: use pathlib instead of os.path
+ # "RET", # flake8-return: good return practices
+ # "S", # flake8-bandit: security testing
+ # "SIM", # flake8-simplify: common simplification rules
+ # "TC", # flake8-type-checking: enforce importing certain types in a TYPE_CHECKING block
+ # "TD", # flake8-todo: Be diligent with TODO comments
+ # "UP", # pyupgrade: Warn if things can changed due to newer versions
+ # "W", # pycodestyle warnings
]
-# Maximum number of characters on a single line.
-max-line-length = 80
-ignore-long-lines = "(?x)(^\\s*(import|from)\\s|^\\s*(\\#\\ )?(https?|ftp):\\/\\/[^\\s\\/$.?#].[^\\s]*>?$|^[a-zA-Z_][a-zA-Z0-9_]*\\s*=\\s*('[^']\\S+'|\"[^\"]\\S+\"))"
-# Maximum number of lines in a module
-max-module-lines = 99999
-# String used as indentation unit. We differ from PEP8's normal 4 spaces.
-indent-string = ' '
-single-line-if-stmt = true
-# Do not warn about multiple statements on a single line for constructs like
-# if test: stmt
-[tool.pylint.LOGGING]
-logging-modules = "logging,absl.logging"
-# Add logging modules.
-[tool.pylint.MISCELLANEOUS]
-# Maximum line length for lambdas
-#short-func-length=1
-# List of module members that should be marked as deprecated.
-# All of the string functions are listed in 4.1.4 Deprecated string functions
-# in the Python 2.4 docs.
-#deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc,sys.maxint
-# List of exceptions that do not need to be mentioned in the Raises section of
-# a docstring.
-#ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError
-# Number of spaces of indent required when the last token on the preceding line
-# is an open (, [, or {.
-indent-after-paren = 4
+ignore = [
+ # Conflicting lint rules with Ruff's formatter
+ # (see https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules).
+ "W191",
+ "E111",
+ "E114",
+ "E117",
+ "D206",
+ "D300",
+ "Q000",
+ "Q001",
+ "Q002",
+ "Q003",
+ "COM812",
+ "COM819",
+ "ISC001",
+ "ISC002",
+ "FBT001",
+ "FBT003",
+ "TD003",
+]
\ No newline at end of file
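To see what the Ruff settings above mean in practice (80-column lines, 2-space indents, single quotes, import sorting via the `I` rules), here is a small hypothetical module laid out the way `ruff format` would leave it under this configuration; it is not taken from the repository.

```python
"""Hypothetical sample laid out per the Ruff settings above."""

import math
from typing import List


def root_mean_square(values: List[float]) -> float:
  """Two-space indents, single quotes, and 80-column lines."""
  if not values:
    return 0.0
  return math.sqrt(sum(v * v for v in values) / len(values))


print(root_mean_square([3.0, 4.0]))  # ~3.54
```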
diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py
index 3d8e35eaa..d080f2fb3 100644
--- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py
+++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py
@@ -3,11 +3,11 @@
import functools
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
@@ -19,26 +19,29 @@ def get_batch_size(workload_name):
def cosine_decay(lr, step, total_steps):
- ratio = jnp.maximum(0., step / total_steps)
- mult = 0.5 * (1. + jnp.cos(jnp.pi * ratio))
+ ratio = jnp.maximum(0.0, step / total_steps)
+ mult = 0.5 * (1.0 + jnp.cos(jnp.pi * ratio))
return mult * lr
-def create_learning_rate_fn(hparams: spec.Hyperparameters,
- steps_per_epoch: int):
+def create_learning_rate_fn(
+ hparams: spec.Hyperparameters, steps_per_epoch: int
+):
"""Create learning rate schedule."""
- base_learning_rate = hparams.learning_rate * get_batch_size('cifar') / 128.
+ base_learning_rate = hparams.learning_rate * get_batch_size('cifar') / 128.0
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=base_learning_rate,
- transition_steps=hparams.warmup_epochs * steps_per_epoch)
+ init_value=0.0,
+ end_value=base_learning_rate,
+ transition_steps=hparams.warmup_epochs * steps_per_epoch,
+ )
cosine_epochs = max(hparams.num_epochs - hparams.warmup_epochs, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=base_learning_rate,
- decay_steps=cosine_epochs * steps_per_epoch)
+ init_value=base_learning_rate, decay_steps=cosine_epochs * steps_per_epoch
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn],
- boundaries=[hparams.warmup_epochs * steps_per_epoch])
+ schedules=[warmup_fn, cosine_fn],
+ boundaries=[hparams.warmup_epochs * steps_per_epoch],
+ )
return schedule_fn
@@ -46,51 +49,59 @@ def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int):
steps_per_epoch = num_train_examples // get_batch_size('cifar')
learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch)
opt_init_fn, opt_update_fn = optax.sgd(
- nesterov=True,
- momentum=hyperparameters.momentum,
- learning_rate=learning_rate_fn)
+ nesterov=True,
+ momentum=hyperparameters.momentum,
+ learning_rate=learning_rate_fn,
+ )
return opt_init_fn, opt_update_fn
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
del model_params
del model_state
del rng
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
- opt_init_fn, opt_update_fn = optimizer(hyperparameters,
- workload.num_train_examples)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
+ opt_init_fn, opt_update_fn = optimizer(
+ hyperparameters, workload.num_train_examples
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, None, 0, 0),
- static_broadcasted_argnums=(0, 1))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- hyperparameters,
- batch,
- rng):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, None, 0, 0),
+ static_broadcasted_argnums=(0, 1),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ hyperparameters,
+ batch,
+ rng,
+):
def _loss_fn(params):
"""loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(batch['targets'], logits)
loss = loss_dict['summed'] / loss_dict['n_valid_examples']
weight_penalty_params = jax.tree_util.tree_leaves(params)
@@ -102,25 +113,27 @@ def _loss_fn(params):
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(_, new_model_state), grad = grad_fn(current_param_container)
grad = lax.pmean(grad, axis_name='batch')
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -130,21 +143,30 @@ def update_params(
optimizer_state, opt_update_fn = optimizer_state
per_device_rngs = jax.random.split(rng, jax.local_device_count())
new_optimizer_state, new_params, new_model_state = pmapped_train_step(
- workload, opt_update_fn, model_state, optimizer_state,
- current_param_container, hyperparameters, batch, per_device_rngs)
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ hyperparameters,
+ batch,
+ per_device_rngs,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -158,14 +180,16 @@ def prepare_for_eval(workload: spec.Workload,
# Not allowed to update the model parameters, hyperparameters, global step, or
# optimizer state.
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
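The `create_learning_rate_fn` above joins a linear warmup to a cosine decay via `optax.join_schedules`. The sketch below shows the same pattern in isolation, with arbitrary step counts instead of the submission's epoch-based values.

```python
import optax


def warmup_cosine_schedule(base_lr, warmup_steps, total_steps):
  """Linear warmup to base_lr, then cosine decay towards zero (illustrative)."""
  warmup = optax.linear_schedule(
    init_value=0.0, end_value=base_lr, transition_steps=warmup_steps
  )
  cosine = optax.cosine_decay_schedule(
    init_value=base_lr, decay_steps=max(total_steps - warmup_steps, 1)
  )
  return optax.join_schedules(
    schedules=[warmup, cosine], boundaries=[warmup_steps]
  )


schedule = warmup_cosine_schedule(base_lr=0.1, warmup_steps=100, total_steps=1000)
print(schedule(0), schedule(100), schedule(1000))  # 0.0, 0.1, ~0.0
```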
diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py
index d8b91f83a..e8080fe34 100644
--- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py
+++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py
@@ -3,9 +3,7 @@
from typing import Any, Dict, Iterator, List, Optional, Tuple
import torch
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
@@ -16,56 +14,63 @@ def get_batch_size(workload_name):
return batch_sizes[workload_name]
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
del workload
del model_state
del rng
- base_lr = hyperparameters.learning_rate * get_batch_size('cifar') / 128.
+ base_lr = hyperparameters.learning_rate * get_batch_size('cifar') / 128.0
optimizer_state = {
- 'optimizer':
- torch.optim.SGD(
- model_params.parameters(),
- lr=base_lr,
- momentum=hyperparameters.momentum,
- weight_decay=hyperparameters.l2),
+ 'optimizer': torch.optim.SGD(
+ model_params.parameters(),
+ lr=base_lr,
+ momentum=hyperparameters.momentum,
+ weight_decay=hyperparameters.l2,
+ ),
}
scheduler1 = LinearLR(
- optimizer_state['optimizer'],
- start_factor=1e-5,
- end_factor=1.,
- total_iters=hyperparameters.warmup_epochs)
+ optimizer_state['optimizer'],
+ start_factor=1e-5,
+ end_factor=1.0,
+ total_iters=hyperparameters.warmup_epochs,
+ )
cosine_epochs = max(
- hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1)
+ hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1
+ )
scheduler2 = CosineAnnealingLR(
- optimizer_state['optimizer'], T_max=cosine_epochs)
+ optimizer_state['optimizer'], T_max=cosine_epochs
+ )
optimizer_state['scheduler'] = SequentialLR(
- optimizer_state['optimizer'],
- schedulers=[scheduler1, scheduler2],
- milestones=[hyperparameters.warmup_epochs])
+ optimizer_state['optimizer'],
+ schedulers=[scheduler1, scheduler2],
+ milestones=[hyperparameters.warmup_epochs],
+ )
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del current_params_types
del hyperparameters
@@ -78,15 +83,17 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'], logits_batch=logits_batch)
+ label_batch=batch['targets'], logits_batch=logits_batch
+ )
loss = loss_dict['summed'] / loss_dict['n_valid_examples']
loss.backward()
@@ -99,16 +106,18 @@ def update_params(
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -122,14 +131,16 @@ def prepare_for_eval(workload: spec.Workload,
# Not allowed to update the model parameters, hyperparameters, global step, or
# optimizer state.
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
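The PyTorch CIFAR submission builds the analogous warmup-then-cosine schedule from `LinearLR`, `CosineAnnealingLR`, and `SequentialLR`. A self-contained sketch with a placeholder model and arbitrary epoch counts (not the submission's tuned values):

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

# Placeholder model; the submission takes its warmup/epoch counts from the
# tuning hyperparameters instead of the constants used here.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
warmup = LinearLR(optimizer, start_factor=1e-5, end_factor=1.0, total_iters=5)
cosine = CosineAnnealingLR(optimizer, T_max=195)
scheduler = SequentialLR(
  optimizer, schedulers=[warmup, cosine], milestones=[5]
)

for _ in range(200):
  optimizer.step()  # stand-in for one epoch of training
  scheduler.step()
print(scheduler.get_last_lr())  # back near zero after the cosine decay
```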
diff --git a/reference_algorithms/development_algorithms/cifar/tuning_search_space.json b/reference_algorithms/development_algorithms/cifar/tuning_search_space.json
index 283341705..aa8fcacfd 100644
--- a/reference_algorithms/development_algorithms/cifar/tuning_search_space.json
+++ b/reference_algorithms/development_algorithms/cifar/tuning_search_space.json
@@ -1,7 +1,7 @@
{
- "learning_rate": {"feasible_points": [0.1]},
- "warmup_epochs": {"feasible_points": [5]},
- "num_epochs": {"feasible_points": [200]},
- "l2": {"feasible_points": [5e-4]},
- "momentum": {"feasible_points": [0.9]}
+ "learning_rate": { "feasible_points": [0.1] },
+ "warmup_epochs": { "feasible_points": [5] },
+ "num_epochs": { "feasible_points": [200] },
+ "l2": { "feasible_points": [5e-4] },
+ "momentum": { "feasible_points": [0.9] }
}
diff --git a/reference_algorithms/development_algorithms/mnist/discrete_space.json b/reference_algorithms/development_algorithms/mnist/discrete_space.json
index 310f19e7d..8056d4861 100644
--- a/reference_algorithms/development_algorithms/mnist/discrete_space.json
+++ b/reference_algorithms/development_algorithms/mnist/discrete_space.json
@@ -1,17 +1,17 @@
[
- {
- "learning_rate": 1e-3,
- "one_minus_beta_1": 0.999,
- "epsilon": 0.9
- },
- {
- "learning_rate": 1e-2,
- "one_minus_beta_1": 0.99,
- "epsilon": 0.99
- },
- {
- "learning_rate": 1e-1,
- "one_minus_beta_1": 0.9,
- "epsilon": 0.999
- }
-]
\ No newline at end of file
+ {
+ "learning_rate": 1e-3,
+ "one_minus_beta_1": 0.999,
+ "epsilon": 0.9
+ },
+ {
+ "learning_rate": 1e-2,
+ "one_minus_beta_1": 0.99,
+ "epsilon": 0.99
+ },
+ {
+ "learning_rate": 1e-1,
+ "one_minus_beta_1": 0.9,
+ "epsilon": 0.999
+ }
+]
diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py
index c1f54597d..afdd1bd43 100644
--- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py
+++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py
@@ -3,11 +3,11 @@
import functools
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
@@ -18,50 +18,59 @@ def get_batch_size(workload_name):
return batch_sizes[workload_name]
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
del model_params
del model_state
del rng
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
opt_init_fn, opt_update_fn = optax.chain(
- optax.scale_by_adam(
- b1=1.0 - hyperparameters.one_minus_beta_1,
- b2=0.999,
- eps=hyperparameters.epsilon),
- optax.scale(-hyperparameters.learning_rate))
+ optax.scale_by_adam(
+ b1=1.0 - hyperparameters.one_minus_beta_1,
+ b2=0.999,
+ eps=hyperparameters.epsilon,
+ ),
+ optax.scale(-hyperparameters.learning_rate),
+ )
return jax_utils.replicate(opt_init_fn(params_zeros_like)), opt_update_fn
# We need to jax.pmap here instead of inside update_params because the latter
# would recompile the function every step.
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, None, 0, 0, 0),
- static_broadcasted_argnums=(0, 1))
-def pmapped_update_params(workload: spec.Workload,
- opt_update_fn,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- optimizer_state: spec.OptimizerState,
- rng: spec.RandomState) -> spec.UpdateReturn:
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, None, 0, 0, 0),
+ static_broadcasted_argnums=(0, 1),
+)
+def pmapped_update_params(
+ workload: spec.Workload,
+ opt_update_fn,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ optimizer_state: spec.OptimizerState,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
del hyperparameters
def loss_fn(params):
logits_batch, new_model_state = workload.model_fn(
- params=params,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=params,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(batch['targets'], logits_batch)
loss = loss_dict['summed'] / loss_dict['n_valid_examples']
return loss, new_model_state
@@ -69,25 +78,27 @@ def loss_fn(params):
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(_, new_model_state), grad = grad_fn(current_param_container)
grad = lax.pmean(grad, axis_name='batch')
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -98,27 +109,30 @@ def update_params(
per_device_rngs = jax.random.split(rng, jax.local_device_count())
optimizer_state, opt_update_fn = optimizer_state
new_optimizer_state, updated_params, new_model_state = pmapped_update_params(
- workload,
- opt_update_fn,
- current_param_container,
- model_state,
- hyperparameters,
- batch,
- optimizer_state,
- per_device_rngs)
+ workload,
+ opt_update_fn,
+ current_param_container,
+ model_state,
+ hyperparameters,
+ batch,
+ optimizer_state,
+ per_device_rngs,
+ )
return (new_optimizer_state, opt_update_fn), updated_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -132,14 +146,16 @@ def prepare_for_eval(workload: spec.Workload,
# Not allowed to update the model parameters, hyperparameters, global step, or
# optimizer state.
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
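The MNIST JAX submission assembles Adam from `optax` gradient transformations (`scale_by_adam` followed by `scale(-lr)`) rather than calling `optax.adam` directly. The sketch below runs one such update on dummy parameters; the hyperparameter values are arbitrary.

```python
import jax.numpy as jnp
import optax

# Adam built as a chain of gradient transformations; the parameter pytree is
# a dummy stand-in for model weights.
lr, one_minus_beta_1, epsilon = 1e-3, 0.1, 1e-8
tx = optax.chain(
  optax.scale_by_adam(b1=1.0 - one_minus_beta_1, b2=0.999, eps=epsilon),
  optax.scale(-lr),
)

params = {'w': jnp.ones((3,)), 'b': jnp.zeros((3,))}
grads = {'w': jnp.full((3,), 0.5), 'b': jnp.full((3,), -0.5)}

opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
print(params['w'])  # each weight moves by roughly -lr after the first step
```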
diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py
index dedd96793..9940fca6e 100644
--- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py
+++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py
@@ -13,38 +13,41 @@ def get_batch_size(workload_name):
return batch_sizes[workload_name]
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
del model_state
del workload
del rng
optimizer_state = {
- 'optimizer':
- torch.optim.Adam(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(1.0 - hyperparameters.one_minus_beta_1, 0.999),
- eps=hyperparameters.epsilon),
+ 'optimizer': torch.optim.Adam(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(1.0 - hyperparameters.one_minus_beta_1, 0.999),
+ eps=hyperparameters.epsilon,
+ ),
}
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del hyperparameters
del loss_type
@@ -59,15 +62,17 @@ def update_params(
param.grad = None
output, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'], logits_batch=output)
+ label_batch=batch['targets'], logits_batch=output
+ )
loss = loss_dict['summed'] / loss_dict['n_valid_examples']
loss.backward()
optimizer_state['optimizer'].step()
@@ -75,16 +80,18 @@ def update_params(
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -98,14 +105,16 @@ def prepare_for_eval(workload: spec.Workload,
# Not allowed to update the model parameters, hyperparameters, global step, or
# optimizer state.
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
diff --git a/reference_algorithms/development_algorithms/mnist/tuning_search_space.json b/reference_algorithms/development_algorithms/mnist/tuning_search_space.json
index 35b941133..bf7d3f1c1 100644
--- a/reference_algorithms/development_algorithms/mnist/tuning_search_space.json
+++ b/reference_algorithms/development_algorithms/mnist/tuning_search_space.json
@@ -1,5 +1,5 @@
{
- "learning_rate": {"min": 1e-4, "max": 1e-2, "scaling": "log"},
- "one_minus_beta_1": {"min": 0.9, "max": 0.999, "scaling": "log"},
- "epsilon": {"feasible_points": [1e-8, 1e-5, 1e-3]}
+ "learning_rate": { "min": 1e-4, "max": 1e-2, "scaling": "log" },
+ "one_minus_beta_1": { "min": 0.9, "max": 0.999, "scaling": "log" },
+ "epsilon": { "feasible_points": [1e-8, 1e-5, 1e-3] }
}
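The tuning search spaces mix two entry types: continuous ranges given by `min`/`max` (optionally with `"scaling": "log"`) and discrete `feasible_points`. The sampler below is a hypothetical illustration of how such a file can be interpreted; it is not the benchmark harness's own tuning code.

```python
import json
import math
import random


def sample_point(space):
  """Draw one hyperparameter point from a search-space dict (illustrative)."""
  point = {}
  for name, cfg in space.items():
    if 'feasible_points' in cfg:
      point[name] = random.choice(cfg['feasible_points'])
    elif cfg.get('scaling') == 'log':
      log_value = random.uniform(math.log(cfg['min']), math.log(cfg['max']))
      point[name] = math.exp(log_value)
    else:
      point[name] = random.uniform(cfg['min'], cfg['max'])
  return point


space = json.loads(
  '{"learning_rate": {"min": 1e-4, "max": 1e-2, "scaling": "log"},'
  ' "one_minus_beta_1": {"min": 0.9, "max": 0.999, "scaling": "log"},'
  ' "epsilon": {"feasible_points": [1e-8, 1e-5, 1e-3]}}'
)
print(sample_point(space))
```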
diff --git a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py
index ff98464ae..32ba97da4 100644
--- a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py
+++ b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py
@@ -30,16 +30,17 @@
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
import jax
-from jax import numpy as jnp
import optax
+from jax import numpy as jnp
JTensor = Any
NestedJTensor = Any
NestedHParams = Any
-def to_quantized(fvalue: JTensor,
- quantized_dtype: jnp.dtype) -> Tuple[JTensor, JTensor]:
+def to_quantized(
+ fvalue: JTensor, quantized_dtype: jnp.dtype
+) -> Tuple[JTensor, JTensor]:
"""Converts floating point values `fvalues` to quantized values.
We use a very simple quantization scheme where the range is symmetric around
@@ -82,16 +83,17 @@ def to_quantized(fvalue: JTensor,
# We first decide the scale.
if fvalue.ndim < 1:
raise ValueError(
- f'Input array {fvalue} must have a strictly positive number of '
- 'dimensions.')
+ f'Input array {fvalue} must have a strictly positive number of '
+ 'dimensions.'
+ )
max_abs = jnp.max(jnp.abs(fvalue), axis=0)
bucket_size = max_abs / num_buckets
bs_expanded = bucket_size[jnp.newaxis, ...]
# To avoid divide by 0.0
- bs_nonzero = jnp.where(bs_expanded > 0.0,
- bs_expanded,
- jnp.ones_like(bs_expanded))
+ bs_nonzero = jnp.where(
+ bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded)
+ )
ratio = fvalue / bs_nonzero
# We use rounding to remove bias.
quantized = jnp.round(ratio)
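The `to_quantized` docstring describes a symmetric scheme: values are scaled by a per-slice bucket size derived from the maximum absolute value along the leading axis, rounded to integers, and rescaled on dequantization. The round-trip below is a sketch of that idea (with `num_buckets=127` for the int8 case), not the optimizer's exact code path.

```python
import jax.numpy as jnp


def quantize_dequantize(fvalue, num_buckets=127):
  """Illustrative symmetric quantization round-trip (sketch only)."""
  max_abs = jnp.max(jnp.abs(fvalue), axis=0)
  bucket_size = max_abs / num_buckets
  # Guard against division by zero for all-zero slices.
  safe_bucket = jnp.where(bucket_size > 0.0, bucket_size, 1.0)
  quantized = jnp.round(fvalue / safe_bucket)  # integers in [-127, 127]
  return quantized * bucket_size  # back to float


x = jnp.array([[0.02, -1.5], [0.5, 3.0], [-0.9, 0.1]])
print(jnp.max(jnp.abs(x - quantize_dequantize(x))))  # small reconstruction error
```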
@@ -128,8 +130,8 @@ def adafactor_decay_rate_adam(beta2: float, step_counter: JTensor) -> JTensor:
"""
step = step_counter
beta2 = jnp.array(beta2, dtype=jnp.float32)
- t = step + 1.
- return beta2 * (1. - jnp.power(beta2, t - 1.)) / (1. - jnp.power(beta2, t))
+ t = step + 1.0
+ return beta2 * (1.0 - jnp.power(beta2, t - 1.0)) / (1.0 - jnp.power(beta2, t))
def adafactor_decay_rate_pow(exponent: float, step_counter: JTensor) -> JTensor:
@@ -145,7 +147,7 @@ def adafactor_decay_rate_pow(exponent: float, step_counter: JTensor) -> JTensor:
"""
step = step_counter
exponent = jnp.array(exponent, dtype=jnp.float32)
- return 1. - jnp.power((step + 1.), -exponent)
+ return 1.0 - jnp.power((step + 1.0), -exponent)
def reduce_mean(array: JTensor) -> JTensor:
@@ -187,6 +189,7 @@ def reduce_rms(array: JTensor) -> JTensor:
@dataclasses.dataclass(frozen=True)
class _ShardedAdafactorUpdateResult:
"""Structure containing per-variable info for Adafactor."""
+
update: Optional[Any]
m: Optional[Any]
m_scale: Optional[Any]
@@ -197,6 +200,7 @@ class _ShardedAdafactorUpdateResult:
class ShardedAdafactorState(NamedTuple):
"""Overall state of the ShardedAdafactor optimizer."""
+
count: JTensor
m: Optional[NestedJTensor]
m_scale: Optional[NestedJTensor]
@@ -208,27 +212,29 @@ class ShardedAdafactorState(NamedTuple):
class _ShardedAdafactorHelper:
"""Helper class to implement optax-based sharded Adafactor."""
- def __init__(self,
- learning_rate: optax.Schedule,
- weight_decay: Optional[float],
- layerwise_adaptation: bool,
- decay_method: str,
- decay_adam: float,
- decay_pow: float,
- beta1: float,
- clip_threshold: Optional[float],
- factored: bool,
- epsilon1_grad_sq_reg: float,
- quantized_dtype: jnp.dtype,
- respect_skip_lp_regularization: bool,
- exclude_from_layerwise_adaptation: Optional[List[str]],
- per_var_learning_summary: bool,
- sort_factored_second_moment_dims: bool,
- min_dim_size_to_factor: int,
- multiply_by_parameter_scale: bool,
- epsilon2_param_scale_reg: float,
- maybe_inf_to_nan: bool,
- nesterov: bool) -> None:
+ def __init__(
+ self,
+ learning_rate: optax.Schedule,
+ weight_decay: Optional[float],
+ layerwise_adaptation: bool,
+ decay_method: str,
+ decay_adam: float,
+ decay_pow: float,
+ beta1: float,
+ clip_threshold: Optional[float],
+ factored: bool,
+ epsilon1_grad_sq_reg: float,
+ quantized_dtype: jnp.dtype,
+ respect_skip_lp_regularization: bool,
+ exclude_from_layerwise_adaptation: Optional[List[str]],
+ per_var_learning_summary: bool,
+ sort_factored_second_moment_dims: bool,
+ min_dim_size_to_factor: int,
+ multiply_by_parameter_scale: bool,
+ epsilon2_param_scale_reg: float,
+ maybe_inf_to_nan: bool,
+ nesterov: bool,
+ ) -> None:
"""Constructor. See ShardedAdafactor() below."""
self._learning_rate = learning_rate
@@ -315,12 +321,13 @@ def should_store_momentum_in_qint(self, shape):
def to_state(self, count, result_tree):
"""Maps from a tree of (factored) values to separate trees of values."""
return ShardedAdafactorState(
- count=count,
- m=jax.tree.map(lambda o: o.m, result_tree),
- m_scale=jax.tree.map(lambda o: o.m_scale, result_tree),
- vr=jax.tree.map(lambda o: o.vr, result_tree),
- vc=jax.tree.map(lambda o: o.vc, result_tree),
- v=jax.tree.map(lambda o: o.v, result_tree))
+ count=count,
+ m=jax.tree.map(lambda o: o.m, result_tree),
+ m_scale=jax.tree.map(lambda o: o.m_scale, result_tree),
+ vr=jax.tree.map(lambda o: o.vr, result_tree),
+ vc=jax.tree.map(lambda o: o.vc, result_tree),
+ v=jax.tree.map(lambda o: o.v, result_tree),
+ )
def init(self, param):
"""Initializes the optimizer state for a given param."""
@@ -353,12 +360,13 @@ def init(self, param):
else:
output_v = jnp.zeros(shape, dtype=jnp.float32)
return _ShardedAdafactorUpdateResult(
- update=output_update,
- m=output_m,
- m_scale=output_m_scale,
- vr=output_vr,
- vc=output_vc,
- v=output_v)
+ update=output_update,
+ m=output_m,
+ m_scale=output_m_scale,
+ vr=output_vr,
+ vc=output_vc,
+ v=output_v,
+ )
def inf_to_nan(self, array):
"""Converting Infinity values to the more sticky NaN."""
@@ -386,16 +394,9 @@ def parameter_scale(self, var):
"""
return jnp.maximum(reduce_rms(var), jnp.asarray(self._epsilon2, var.dtype))
- def compute_var_and_slot_update(self,
- count,
- grad,
- m,
- m_scale,
- vr,
- vc,
- v,
- param,
- var_name=None):
+ def compute_var_and_slot_update(
+ self, count, grad, m, m_scale, vr, vc, v, param, var_name=None
+ ):
"""Computes the var and optimizer slots updates for a single variable."""
# We can probably skip this step
grad = grad.astype(jnp.float32)
@@ -434,7 +435,7 @@ def compute_var_and_slot_update(self,
update_scale += grad_squared_mean * 1e-30
# END HACK
- mixing_rate = 1. - decay_rate
+ mixing_rate = 1.0 - decay_rate
shape = param.shape
output_m = jnp.zeros((1,))
@@ -449,18 +450,23 @@ def compute_var_and_slot_update(self,
# reduce_mean().
vr_axis, vc_axis = factored_second_moment_dims
grad_squared_row_mean = self.inf_to_nan(
- jnp.mean(grad_squared, axis=vr_axis))
+ jnp.mean(grad_squared, axis=vr_axis)
+ )
grad_squared_col_mean = self.inf_to_nan(
- jnp.mean(grad_squared, axis=vc_axis))
+ jnp.mean(grad_squared, axis=vc_axis)
+ )
new_vr = decay_rate * vr + mixing_rate * grad_squared_row_mean
new_vc = decay_rate * vc + mixing_rate * grad_squared_col_mean
output_vr = new_vr
output_vc = new_vc
long_term_mean = jnp.mean(new_vr, axis=-1, keepdims=True)
- r_factor = 1. / jnp.sqrt(new_vr / long_term_mean)
- c_factor = 1. / jnp.sqrt(new_vc)
- x = grad * jnp.expand_dims(r_factor, vr_axis) * jnp.expand_dims(
- c_factor, vc_axis)
+ r_factor = 1.0 / jnp.sqrt(new_vr / long_term_mean)
+ c_factor = 1.0 / jnp.sqrt(new_vc)
+ x = (
+ grad
+ * jnp.expand_dims(r_factor, vr_axis)
+ * jnp.expand_dims(c_factor, vc_axis)
+ )
else:
# v with sharding annotation.
new_v = decay_rate * v + mixing_rate * grad_squared
@@ -468,7 +474,7 @@ def compute_var_and_slot_update(self,
x = grad / jnp.sqrt(new_v)
if self._clip_threshold is not None:
- clipping_denom = jnp.maximum(1., reduce_rms(x) / self._clip_threshold)
+ clipping_denom = jnp.maximum(1.0, reduce_rms(x) / self._clip_threshold)
clipping_denom = self.inf_to_nan(clipping_denom)
x /= clipping_denom
@@ -481,7 +487,7 @@ def compute_var_and_slot_update(self,
m = to_float(m, m_scale)
if self._nesterov:
subtrahend_original = subtrahend
- subtrahend = self._beta1 * m + (1. - self._beta1) * subtrahend
+ subtrahend = self._beta1 * m + (1.0 - self._beta1) * subtrahend
subtrahend = self.inf_to_nan(subtrahend)
if self._quantized_dtype == jnp.bfloat16:
new_m = subtrahend.astype(jnp.bfloat16)
@@ -496,8 +502,8 @@ def compute_var_and_slot_update(self,
if self._nesterov:
subtrahend = (
- self._beta1 * subtrahend +
- (1.0 - self._beta1) * subtrahend_original)
+ self._beta1 * subtrahend + (1.0 - self._beta1) * subtrahend_original
+ )
if self._weight_decay is not None:
# Apply decoupled weight decay to be consistent with AdamW.
@@ -527,43 +533,45 @@ def compute_var_and_slot_update(self,
g_norm = reduce_rms(subtrahend / update_scale) + self._epsilon1
ratio = w_norm / g_norm
ratio = jnp.where(
- jnp.greater(w_norm, 0),
- jnp.where(jnp.greater(g_norm, 0), (w_norm / g_norm), 1.0),
- 1.0)
+ jnp.greater(w_norm, 0),
+ jnp.where(jnp.greater(g_norm, 0), (w_norm / g_norm), 1.0),
+ 1.0,
+ )
subtrahend *= ratio
return _ShardedAdafactorUpdateResult(
- update=-subtrahend,
- m=output_m,
- m_scale=output_m_scale,
- vr=output_vr,
- vc=output_vc,
- v=output_v)
+ update=-subtrahend,
+ m=output_m,
+ m_scale=output_m_scale,
+ vr=output_vr,
+ vc=output_vc,
+ v=output_v,
+ )
def sharded_adafactor(
- learning_rate: optax.Schedule,
- weight_decay: Optional[Union[float, Dict[str, float]]] = None,
- layerwise_adaptation: bool = False,
- decay_method: str = 'adam',
- decay_adam: float = 0.99,
- decay_pow: float = 0.,
- beta1: float = 0.9,
- clip_threshold: Optional[float] = 1.,
- factored: bool = True,
- epsilon1_grad_sq_reg: float = 1e-30,
- quantized_dtype: jnp.dtype = jnp.int8,
- respect_skip_lp_regularization: bool = False,
- exclude_from_layerwise_adaptation: Optional[List[str]] = None,
- per_var_learning_summary: bool = False,
- sort_factored_second_moment_dims: bool = False,
- # min_dim_size_to_factor is only used when
- # sort_factored_second_moment_dims=True.
- min_dim_size_to_factor: int = 128,
- multiply_by_parameter_scale: bool = False,
- epsilon2_param_scale_reg: float = 1e-3,
- maybe_inf_to_nan: bool = True,
- nesterov: bool = False,
+ learning_rate: optax.Schedule,
+ weight_decay: Optional[Union[float, Dict[str, float]]] = None,
+ layerwise_adaptation: bool = False,
+ decay_method: str = 'adam',
+ decay_adam: float = 0.99,
+ decay_pow: float = 0.0,
+ beta1: float = 0.9,
+ clip_threshold: Optional[float] = 1.0,
+ factored: bool = True,
+ epsilon1_grad_sq_reg: float = 1e-30,
+ quantized_dtype: jnp.dtype = jnp.int8,
+ respect_skip_lp_regularization: bool = False,
+ exclude_from_layerwise_adaptation: Optional[List[str]] = None,
+ per_var_learning_summary: bool = False,
+ sort_factored_second_moment_dims: bool = False,
+ # min_dim_size_to_factor is only used when
+ # sort_factored_second_moment_dims=True.
+ min_dim_size_to_factor: int = 128,
+ multiply_by_parameter_scale: bool = False,
+ epsilon2_param_scale_reg: float = 1e-3,
+ maybe_inf_to_nan: bool = True,
+ nesterov: bool = False,
) -> optax.GradientTransformation:
"""AdaFactor optimizer that supports SPMD sharding.
@@ -638,53 +646,60 @@ def sharded_adafactor(
assert decay_pow >= 0
assert learning_rate is not None
assert decay_method == 'adam' or decay_method == 'pow', (
- f'decay_method: {decay_method} not supported. Supported methods are '
- '"pow", or "adam".')
+ f'decay_method: {decay_method} not supported. Supported methods are '
+ '"pow", or "adam".'
+ )
sharded_adafactor_helper = _ShardedAdafactorHelper(
- learning_rate=learning_rate,
- weight_decay=weight_decay,
- layerwise_adaptation=layerwise_adaptation,
- decay_method=decay_method,
- decay_adam=decay_adam,
- decay_pow=decay_pow,
- beta1=beta1,
- clip_threshold=clip_threshold,
- factored=factored,
- epsilon1_grad_sq_reg=epsilon1_grad_sq_reg,
- quantized_dtype=quantized_dtype,
- respect_skip_lp_regularization=respect_skip_lp_regularization,
- exclude_from_layerwise_adaptation=exclude_from_layerwise_adaptation,
- per_var_learning_summary=per_var_learning_summary,
- sort_factored_second_moment_dims=sort_factored_second_moment_dims,
- min_dim_size_to_factor=min_dim_size_to_factor,
- multiply_by_parameter_scale=multiply_by_parameter_scale,
- epsilon2_param_scale_reg=epsilon2_param_scale_reg,
- maybe_inf_to_nan=maybe_inf_to_nan,
- nesterov=nesterov)
+ learning_rate=learning_rate,
+ weight_decay=weight_decay,
+ layerwise_adaptation=layerwise_adaptation,
+ decay_method=decay_method,
+ decay_adam=decay_adam,
+ decay_pow=decay_pow,
+ beta1=beta1,
+ clip_threshold=clip_threshold,
+ factored=factored,
+ epsilon1_grad_sq_reg=epsilon1_grad_sq_reg,
+ quantized_dtype=quantized_dtype,
+ respect_skip_lp_regularization=respect_skip_lp_regularization,
+ exclude_from_layerwise_adaptation=exclude_from_layerwise_adaptation,
+ per_var_learning_summary=per_var_learning_summary,
+ sort_factored_second_moment_dims=sort_factored_second_moment_dims,
+ min_dim_size_to_factor=min_dim_size_to_factor,
+ multiply_by_parameter_scale=multiply_by_parameter_scale,
+ epsilon2_param_scale_reg=epsilon2_param_scale_reg,
+ maybe_inf_to_nan=maybe_inf_to_nan,
+ nesterov=nesterov,
+ )
def init_fn(params):
"""Initializes the optimizer's state."""
return sharded_adafactor_helper.to_state(
- jnp.zeros([], jnp.int32),
- jax.tree.map(sharded_adafactor_helper.init, params))
+ jnp.zeros([], jnp.int32),
+ jax.tree.map(sharded_adafactor_helper.init, params),
+ )
def update_fn(updates, state, params=None):
if params is None:
raise ValueError(
- 'You are using a transformation that requires the current value of '
- 'parameters, but you are not passing `params` when calling `update`.')
+ 'You are using a transformation that requires the current value of '
+ 'parameters, but you are not passing `params` when calling `update`.'
+ )
compute_var_and_slot_update_fn = functools.partial(
- sharded_adafactor_helper.compute_var_and_slot_update, state.count)
- output = jax.tree.map(compute_var_and_slot_update_fn,
- updates,
- state.m,
- state.m_scale,
- state.vr,
- state.vc,
- state.v,
- params)
+ sharded_adafactor_helper.compute_var_and_slot_update, state.count
+ )
+ output = jax.tree.map(
+ compute_var_and_slot_update_fn,
+ updates,
+ state.m,
+ state.m_scale,
+ state.vr,
+ state.vc,
+ state.v,
+ params,
+ )
updates = jax.tree.map(lambda o: o.update, output)
count_plus_one = state.count + jnp.array(1, jnp.int32)
updated_states = sharded_adafactor_helper.to_state(count_plus_one, output)
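For reference, the Adam-style decay rate that the reformatted `adafactor_decay_rate_adam` lines compute can be exercised in isolation. This is a minimal illustration of the same formula, not part of the patch:

```python
import jax.numpy as jnp


def adam_style_decay_rate(beta2: float, step: int) -> jnp.ndarray:
  """Bias-corrected second-moment decay, as in adafactor_decay_rate_adam."""
  beta2 = jnp.asarray(beta2, dtype=jnp.float32)
  t = step + 1.0
  return beta2 * (1.0 - jnp.power(beta2, t - 1.0)) / (1.0 - jnp.power(beta2, t))


# At step 0 this evaluates to 0, so the first update relies entirely on the
# fresh gradient statistics; it approaches beta2 as training progresses.
print(adam_style_decay_rate(0.99, 0))    # ~0.0
print(adam_style_decay_rate(0.99, 999))  # ~0.99
```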
diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py
index 1833ab8af..8dcaa6578 100644
--- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py
+++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py
@@ -3,24 +3,27 @@
import functools
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
-from reference_algorithms.paper_baselines.adafactor.jax.sharded_adafactor import \
- sharded_adafactor
+from reference_algorithms.paper_baselines.adafactor.jax.sharded_adafactor import (
+ sharded_adafactor,
+)
_GRAD_CLIP_EPS = 1e-6
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates an Adafactor optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -30,99 +33,113 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+ init_value=hyperparameters.learning_rate, decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters)
opt_init_fn, opt_update_fn = sharded_adafactor(
- learning_rate=lr_schedule_fn,
- beta1=1.0 - hyperparameters.one_minus_beta1,
- weight_decay=hyperparameters.weight_decay)
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ learning_rate=lr_schedule_fn,
+ beta1=1.0 - hyperparameters.one_minus_beta1,
+ weight_decay=hyperparameters.weight_decay,
+ )
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -139,37 +156,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -205,14 +228,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py
index 7aa457a25..4c96e5562 100644
--- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py
@@ -3,12 +3,10 @@
from functools import partial
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import torch
import torch.distributed.nn as dist_nn
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from absl import logging
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
@@ -16,36 +14,40 @@
USE_PYTORCH_DDP = pytorch_setup()[0]
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates an Adafactor optimizer and a learning rate schedule."""
del model_state
del rng
# Create optimizer.
optimizer_state = {
- 'optimizer':
- Adafactor(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- beta1=1 - hyperparameters.one_minus_beta1,
- weight_decay=hyperparameters.weight_decay),
+ 'optimizer': Adafactor(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ beta1=1 - hyperparameters.one_minus_beta1,
+ weight_decay=hyperparameters.weight_decay,
+ ),
}
optimizer = optimizer_state['optimizer']
warmup = LinearLR(
- optimizer,
- start_factor=1e-10,
- end_factor=1.,
- total_iters=hyperparameters.warmup_steps)
+ optimizer,
+ start_factor=1e-10,
+ end_factor=1.0,
+ total_iters=hyperparameters.warmup_steps,
+ )
cosine_steps = max(workload.step_hint - hyperparameters.warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
optimizer_state['scheduler'] = SequentialLR(
- optimizer,
- schedulers=[warmup, cosine_decay],
- milestones=[hyperparameters.warmup_steps])
+ optimizer,
+ schedulers=[warmup, cosine_decay],
+ milestones=[hyperparameters.warmup_steps],
+ )
return optimizer_state
@@ -54,56 +56,56 @@ class Adafactor(torch.optim.Optimizer):
src/transformers/optimization.py#L386"""
def __init__(
- self,
- params,
- lr=None,
- beta1=0.9,
- decay_adam=0.99,
- weight_decay=0.0,
+ self,
+ params,
+ lr=None,
+ beta1=0.9,
+ decay_adam=0.99,
+ weight_decay=0.0,
):
defaults = dict(
- lr=lr,
- beta1=beta1,
- decay_adam=decay_adam,
- weight_decay=weight_decay,
- decay_pow=0.0,
- layerwise_adaptation=False,
- decay_method='adam',
- clip_threshold=1.0,
- factored=True,
- epsilon1_grad_sq_reg=1e-30,
- respect_skip_lp_regularization=False,
- exclude_from_layerwise_adaptation=None,
- per_var_learning_summary=False,
- sort_factored_second_moment_dims=False,
- # Unused because sort_factored_second_moment_dims=False.
- min_dim_size_to_factor=128,
- multiply_by_parameter_scale=False,
- # Unused because multiply_by_parameter_scale=False.
- epsilon2_param_scale_reg=1e-3,
- maybe_inf_to_nan=True,
+ lr=lr,
+ beta1=beta1,
+ decay_adam=decay_adam,
+ weight_decay=weight_decay,
+ decay_pow=0.0,
+ layerwise_adaptation=False,
+ decay_method='adam',
+ clip_threshold=1.0,
+ factored=True,
+ epsilon1_grad_sq_reg=1e-30,
+ respect_skip_lp_regularization=False,
+ exclude_from_layerwise_adaptation=None,
+ per_var_learning_summary=False,
+ sort_factored_second_moment_dims=False,
+ # Unused because sort_factored_second_moment_dims=False.
+ min_dim_size_to_factor=128,
+ multiply_by_parameter_scale=False,
+ # Unused because multiply_by_parameter_scale=False.
+ epsilon2_param_scale_reg=1e-3,
+ maybe_inf_to_nan=True,
)
super().__init__(params, defaults)
def inf_to_nan(self, group, x):
- if group["maybe_inf_to_nan"]:
+ if group['maybe_inf_to_nan']:
x = torch.nan_to_num(x, nan=torch.nan, posinf=torch.nan, neginf=torch.nan)
return x
def step(self, closure=None):
"""
- Performs a single optimization step
- Arguments:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
- """
+ Performs a single optimization step
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
inf_to_nan = partial(self.inf_to_nan, group)
- for p in group["params"]:
+ for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
@@ -111,7 +113,7 @@ def step(self, closure=None):
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
- raise RuntimeError("Adafactor does not support sparse gradients.")
+ raise RuntimeError('Adafactor does not support sparse gradients.')
state = self.state[p]
grad_shape = grad.shape
@@ -120,51 +122,54 @@ def step(self, closure=None):
# State Initialization
if len(state) == 0:
- state["step"] = 0
- state["exp_avg"] = torch.zeros_like(grad)
+ state['step'] = 0
+ state['exp_avg'] = torch.zeros_like(grad)
if factored:
- state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
- state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] +
- grad_shape[-1:]).to(grad)
+ state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
+ state['exp_avg_sq_col'] = torch.zeros(
+ grad_shape[:-2] + grad_shape[-1:]
+ ).to(grad)
else:
- state["exp_avg_sq"] = torch.zeros_like(grad)
+ state['exp_avg_sq'] = torch.zeros_like(grad)
else:
- state["exp_avg"] = state["exp_avg"].to(grad)
+ state['exp_avg'] = state['exp_avg'].to(grad)
if factored:
- state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
- state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
+ state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
+ state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
else:
- state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
+ state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)
p_data_fp32 = p.data
if p.data.dtype in {torch.float16, torch.bfloat16}:
p_data_fp32 = p_data_fp32.float()
- state["step"] += 1
- lr = group["lr"]
- beta1 = group["beta1"]
- beta2 = group["decay_adam"]
+ state['step'] += 1
+ lr = group['lr']
+ beta1 = group['beta1']
+ beta2 = group['decay_adam']
- t = state["step"]
- beta2t = beta2 * (1. - beta2**(t - 1.)) / (1. - beta2**t)
+ t = state['step']
+ beta2t = beta2 * (1.0 - beta2 ** (t - 1.0)) / (1.0 - beta2**t)
- exp_avg_sq_update = (grad**2) + group["epsilon1_grad_sq_reg"]
+ exp_avg_sq_update = (grad**2) + group['epsilon1_grad_sq_reg']
if factored:
- exp_avg_sq_row = state["exp_avg_sq_row"]
- exp_avg_sq_col = state["exp_avg_sq_col"]
+ exp_avg_sq_row = state['exp_avg_sq_row']
+ exp_avg_sq_col = state['exp_avg_sq_col']
exp_avg_sq_row.mul_(beta2t).add_(
- exp_avg_sq_update.mean(dim=-1), alpha=1.0 - beta2t)
+ exp_avg_sq_update.mean(dim=-1), alpha=1.0 - beta2t
+ )
exp_avg_sq_col.mul_(beta2t).add_(
- exp_avg_sq_update.mean(dim=-2), alpha=1.0 - beta2t)
+ exp_avg_sq_update.mean(dim=-2), alpha=1.0 - beta2t
+ )
r_factor = inf_to_nan(
- exp_avg_sq_row /
- exp_avg_sq_row.mean(dim=-1, keepdim=True)).unsqueeze(-1)
+ exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)
+ ).unsqueeze(-1)
c_factor = inf_to_nan(exp_avg_sq_col).unsqueeze(-2)
denom = r_factor * c_factor
else:
- exp_avg_sq = state["exp_avg_sq"]
+ exp_avg_sq = state['exp_avg_sq']
exp_avg_sq.mul_(beta2t).add_(exp_avg_sq_update, alpha=1.0 - beta2t)
denom = exp_avg_sq
@@ -172,15 +177,16 @@ def step(self, closure=None):
denom = denom.sqrt()
update = grad / denom
# Clip the update based on RMS.
- clipping_denom = inf_to_nan(torch.square(update).mean().sqrt() \
- /group["clip_threshold"]).clamp(min=1.0)
+ clipping_denom = inf_to_nan(
+ torch.square(update).mean().sqrt() / group['clip_threshold']
+ ).clamp(min=1.0)
update = update / clipping_denom * lr
# Momentum
- exp_avg = state["exp_avg"]
+ exp_avg = state['exp_avg']
exp_avg.mul_(beta1).add_(update, alpha=1 - beta1)
- if group["weight_decay"] != 0:
- p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * lr)
+ if group['weight_decay'] != 0:
+ p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr)
p_data_fp32.add_(-exp_avg)
@@ -191,18 +197,19 @@ def step(self, closure=None):
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -214,22 +221,26 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -243,12 +254,14 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -256,28 +269,34 @@ def update_params(
if global_step <= 100 or global_step % 500 == 0:
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -314,14 +333,15 @@ def get_batch_size(workload_name):
def data_selection(
- workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]:
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json b/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json
index 5543689ea..37b36e55d 100644
--- a/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json
+++ b/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json
@@ -1,20 +1,26 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "min": 1e-2, "max": 0.45, "scaling": "log"
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 1e-3, "max": 1.0, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "min": 1e-2,
+ "max": 0.45,
+ "scaling": "log"
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 1e-3,
+ "max": 1.0,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ }
}
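The reformatted search space keeps the same ranges and scalings. As a purely hypothetical illustration (the benchmark's own tuning harness is not part of this diff and may differ), a log-scaled entry such as `learning_rate` would typically be sampled uniformly in log space:

```python
import json
import math
import random

space = json.loads(
  '{"learning_rate": {"min": 1e-4, "max": 1e-2, "scaling": "log"}}'
)
cfg = space['learning_rate']
sample = 10 ** random.uniform(math.log10(cfg['min']), math.log10(cfg['max']))
print(sample)  # somewhere in [1e-4, 1e-2], uniform in log space
```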
diff --git a/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json
index 98a506084..35f106f9d 100644
--- a/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json
+++ b/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json
@@ -1,20 +1,24 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "feasible_points": [0.1]
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 1e-3, "max": 1.0, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "feasible_points": [0.1]
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 1e-3,
+ "max": 1.0,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ }
}
diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py
index dde41fa6d..17d0c2fc2 100644
--- a/reference_algorithms/paper_baselines/adamw/jax/submission.py
+++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py
@@ -3,22 +3,24 @@
import functools
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
_GRAD_CLIP_EPS = 1e-6
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates an AdamW optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -28,101 +30,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+ init_value=hyperparameters.learning_rate, decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters)
opt_init_fn, opt_update_fn = optax.adamw(
- learning_rate=lr_schedule_fn,
- b1=1.0 - hyperparameters.one_minus_beta1,
- b2=hyperparameters.beta2,
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay)
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ learning_rate=lr_schedule_fn,
+ b1=1.0 - hyperparameters.one_minus_beta1,
+ b2=hyperparameters.beta2,
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ )
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -139,37 +155,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -209,14 +231,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py
index 21d9b6b57..5df907160 100644
--- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py
@@ -2,12 +2,10 @@
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import torch
import torch.distributed.nn as dist_nn
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from absl import logging
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
@@ -15,55 +13,60 @@
USE_PYTORCH_DDP = pytorch_setup()[0]
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates an AdamW optimizer and a learning rate schedule."""
del model_state
del rng
optimizer_state = {
- 'optimizer':
- torch.optim.AdamW(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(1.0 - hyperparameters.one_minus_beta1,
- hyperparameters.beta2),
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay,
- fused=False),
+ 'optimizer': torch.optim.AdamW(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2),
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ fused=False,
+ ),
}
def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup_steps = hyperparameters.warmup_factor * step_hint
warmup = LinearLR(
- optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
+ optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
- optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
+ optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
+ )
optimizer_state['scheduler'] = pytorch_cosine_warmup(
- workload.step_hint, hyperparameters, optimizer_state['optimizer'])
+ workload.step_hint, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -75,22 +78,26 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -104,7 +111,8 @@ def update_params(
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -113,31 +121,38 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -177,14 +192,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/adamw/tuning_search_space.json b/reference_algorithms/paper_baselines/adamw/tuning_search_space.json
index c96b03eda..abdd6e32d 100644
--- a/reference_algorithms/paper_baselines/adamw/tuning_search_space.json
+++ b/reference_algorithms/paper_baselines/adamw/tuning_search_space.json
@@ -1,23 +1,29 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "min": 2e-2, "max": 0.5, "scaling": "log"
- },
- "beta2": {
- "feasible_points": [0.999]
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 5e-3, "max": 1.0, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "min": 2e-2,
+ "max": 0.5,
+ "scaling": "log"
+ },
+ "beta2": {
+ "feasible_points": [0.999]
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 5e-3,
+ "max": 1.0,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ }
}
diff --git a/reference_algorithms/paper_baselines/adamw/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/adamw/tuning_search_space_no_beta1.json
index b8bd2ea49..5a7c27be7 100644
--- a/reference_algorithms/paper_baselines/adamw/tuning_search_space_no_beta1.json
+++ b/reference_algorithms/paper_baselines/adamw/tuning_search_space_no_beta1.json
@@ -1,23 +1,27 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "feasible_points": [0.1]
- },
- "beta2": {
- "feasible_points": [0.999]
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 5e-3, "max": 1.0, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "feasible_points": [0.1]
+ },
+ "beta2": {
+ "feasible_points": [0.999]
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 5e-3,
+ "max": 1.0,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ }
}
diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py
index 70e305514..168c0579b 100644
--- a/reference_algorithms/paper_baselines/lamb/jax/submission.py
+++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py
@@ -3,11 +3,11 @@
import functools
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
@@ -21,11 +21,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True):
return optax.scale(m * learning_rate)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a LAMB optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -35,61 +37,70 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+ init_value=hyperparameters.learning_rate, decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters)
opt_init_fn, opt_update_fn = optax.lamb(
- learning_rate=lr_schedule_fn,
- b1=1 - hyperparameters.one_minus_beta1,
- b2=hyperparameters.beta2,
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay)
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ learning_rate=lr_schedule_fn,
+ b1=1 - hyperparameters.one_minus_beta1,
+ b2=hyperparameters.beta2,
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ )
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
@@ -97,40 +108,45 @@ def _loss_fn(params):
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -147,37 +163,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -213,14 +235,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py
index c1c6cec0a..c73f89e71 100644
--- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py
@@ -3,25 +3,19 @@
import math
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import torch
+from absl import logging
from torch import Tensor
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py
class LAMB(torch.optim.Optimizer):
-
- def __init__(self,
- params,
- lr=1e-3,
- betas=(0.9, 0.999),
- eps=1e-8,
- weight_decay=0.0):
+ def __init__(
+ self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0
+ ):
if not 0.0 <= lr:
raise ValueError(f'Invalid learning rate: {lr}')
if not 0.0 <= eps:
@@ -39,7 +33,8 @@ def __setstate__(self, state):
super().__setstate__(state)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
- state_values[0]['step'])
+ state_values[0]['step']
+ )
if not step_is_tensor:
for s in state_values:
s['step'] = torch.tensor(float(s['step']))
@@ -74,48 +69,53 @@ def step(self, closure=None):
# State initialization
if len(state) == 0:
- state['step'] = torch.tensor(0.)
+ state['step'] = torch.tensor(0.0)
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
state_steps.append(state['step'])
lamb(
- params_with_grad,
- grads,
- exp_avgs,
- exp_avg_sqs,
- state_steps,
- beta1=beta1,
- beta2=beta2,
- lr=group['lr'],
- weight_decay=group['weight_decay'],
- eps=group['eps'])
+ params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ state_steps,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ )
return loss
-def lamb(params: List[Tensor],
- grads: List[Tensor],
- exp_avgs: List[Tensor],
- exp_avg_sqs: List[Tensor],
- state_steps: List[Tensor],
- beta1: float,
- beta2: float,
- lr: float,
- weight_decay: float,
- eps: float):
-
+def lamb(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ state_steps: List[Tensor],
+ beta1: float,
+ beta2: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+):
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
- 'API has changed, `state_steps` argument must contain a list of' +
- ' singleton tensors')
+ 'API has changed, `state_steps` argument must contain a list of'
+ + ' singleton tensors'
+ )
for i, param in enumerate(params):
grad = grads[i]
@@ -147,61 +147,67 @@ def lamb(params: List[Tensor],
update_norm = torch.linalg.norm(update)
# Set trust_ratio to 1 in case where parameters would never be updated.
- if param_norm == 0. or update_norm == 0.:
- trust_ratio = 1.
+ if param_norm == 0.0 or update_norm == 0.0:
+ trust_ratio = 1.0
else:
trust_ratio = param_norm / update_norm
param.add_(update, alpha=-lr * trust_ratio)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a LAMB optimizer and a learning rate schedule."""
del model_state
del rng
optimizer_state = {
- 'optimizer':
- LAMB(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(hyperparameters.beta1, hyperparameters.beta2),
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay)
+ 'optimizer': LAMB(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(hyperparameters.beta1, hyperparameters.beta2),
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ )
}
def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup = LinearLR(
- optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
+ optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
- optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
+ optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
+ )
optimizer_state['scheduler'] = pytorch_cosine_warmup(
- workload.step_hint, hyperparameters, optimizer_state['optimizer'])
+ workload.step_hint, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -213,31 +219,36 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
loss, _ = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
loss.backward()
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -246,31 +257,38 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -306,14 +324,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/lamb/tuning_search_space.json b/reference_algorithms/paper_baselines/lamb/tuning_search_space.json
index f2fcde461..7de33fe47 100644
--- a/reference_algorithms/paper_baselines/lamb/tuning_search_space.json
+++ b/reference_algorithms/paper_baselines/lamb/tuning_search_space.json
@@ -1,23 +1,29 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "min": 5e-2, "max": 0.3, "scaling": "log"
- },
- "beta2": {
- "feasible_points": [0.999]
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 1e-3, "max": 1.0, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "min": 5e-2,
+ "max": 0.3,
+ "scaling": "log"
+ },
+ "beta2": {
+ "feasible_points": [0.999]
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 1e-3,
+ "max": 1.0,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ }
}
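
The `"min"` / `"max"` / `"scaling": "log"` entries above describe log-scaled search ranges. As a rough illustration of what such an entry implies (this is an assumption about the sampling convention, not the benchmark's actual tuning harness), a log-uniform draw over the `learning_rate` range looks like:

import math
import random

def sample_log_uniform(min_value: float, max_value: float, rng: random.Random) -> float:
  """Draw one value uniformly in log space between min_value and max_value."""
  return math.exp(rng.uniform(math.log(min_value), math.log(max_value)))

rng = random.Random(0)
# Placeholder draw over the learning_rate range defined above.
learning_rate = sample_log_uniform(1e-4, 1e-2, rng)
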
diff --git a/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json
index 8934e512d..0f1bc208a 100644
--- a/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json
+++ b/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json
@@ -1,23 +1,27 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "feasible_points": [0.1]
- },
- "beta2": {
- "feasible_points": [0.999]
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 1e-3, "max": 1.0, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "feasible_points": [0.1]
+ },
+ "beta2": {
+ "feasible_points": [0.999]
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 1e-3,
+ "max": 1.0,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ }
}
diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py
index cbb6d6dcd..df084c17b 100644
--- a/reference_algorithms/paper_baselines/momentum/jax/submission.py
+++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py
@@ -3,22 +3,24 @@
import functools
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
_GRAD_CLIP_EPS = 1e-6
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -28,34 +30,39 @@ def init_optimizer_state(workload: spec.Workload,
lr_schedule_fn = create_lr_schedule_fn(workload.step_hint, hyperparameters)
# Create optimizer.
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
opt_init_fn, opt_update_fn = sgd(
- learning_rate=lr_schedule_fn,
- weight_decay=hyperparameters.weight_decay,
- momentum=1.0 - hyperparameters.one_minus_beta1,
- nesterov=False)
+ learning_rate=lr_schedule_fn,
+ weight_decay=hyperparameters.weight_decay,
+ momentum=1.0 - hyperparameters.one_minus_beta1,
+ nesterov=False,
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
def create_lr_schedule_fn(
- step_hint: int,
- hyperparameters: spec.Hyperparameters) -> Callable[[int], float]:
+ step_hint: int, hyperparameters: spec.Hyperparameters
+) -> Callable[[int], float]:
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
decay_steps = step_hint - warmup_steps
polynomial_schedule_fn = optax.polynomial_schedule(
- init_value=hyperparameters.learning_rate,
- end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
- power=1,
- transition_steps=int(decay_steps * hyperparameters.decay_steps_factor))
+ init_value=hyperparameters.learning_rate,
+ end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
+ power=1,
+ transition_steps=int(decay_steps * hyperparameters.decay_steps_factor),
+ )
lr_schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps]
+ )
return lr_schedule_fn
@@ -82,81 +89,92 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False):
{SGD, Momentum, Nesterov} update.
"""
return optax.chain(
- optax.add_decayed_weights(weight_decay),
- optax.sgd(
- learning_rate=learning_rate, momentum=momentum, nesterov=nesterov))
+ optax.add_decayed_weights(weight_decay),
+ optax.sgd(
+ learning_rate=learning_rate, momentum=momentum, nesterov=nesterov
+ ),
+ )
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -173,37 +191,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -243,14 +267,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py
index c3760d20e..b3d38b3dd 100644
--- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py
@@ -2,10 +2,10 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import optax
import torch
import torch.distributed.nn as dist_nn
+from absl import logging
from torch.optim.lr_scheduler import LambdaLR
from algoperf import spec
@@ -14,24 +14,26 @@
USE_PYTORCH_DDP = pytorch_setup()[0]
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule."""
del model_state
del rng
# Create optimizer.
optimizer_state = {
- 'optimizer':
- torch.optim.SGD(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- momentum=1.0 - hyperparameters.one_minus_beta1,
- weight_decay=hyperparameters.weight_decay,
- nesterov=False),
+ 'optimizer': torch.optim.SGD(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ momentum=1.0 - hyperparameters.one_minus_beta1,
+ weight_decay=hyperparameters.weight_decay,
+ nesterov=False,
+ ),
}
# Create learning rate schedule.
@@ -43,43 +45,48 @@ def _lr_lambda(step: int) -> float:
return lr_schedule_fn(step).item() / hyperparameters.learning_rate
optimizer_state['scheduler'] = LambdaLR(
- optimizer_state['optimizer'], lr_lambda=_lr_lambda)
+ optimizer_state['optimizer'], lr_lambda=_lr_lambda
+ )
return optimizer_state
def create_lr_schedule_fn(
- step_hint: int,
- hyperparameters: spec.Hyperparameters) -> Callable[[int], float]:
+ step_hint: int, hyperparameters: spec.Hyperparameters
+) -> Callable[[int], float]:
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
decay_steps = step_hint - warmup_steps
polynomial_schedule_fn = optax.polynomial_schedule(
- init_value=hyperparameters.learning_rate,
- end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
- power=1,
- transition_steps=int(decay_steps * hyperparameters.decay_steps_factor))
+ init_value=hyperparameters.learning_rate,
+ end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
+ power=1,
+ transition_steps=int(decay_steps * hyperparameters.decay_steps_factor),
+ )
lr_schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps]
+ )
return lr_schedule_fn
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -91,26 +98,30 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -123,7 +134,8 @@ def update_params(
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -132,31 +144,38 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -196,14 +215,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/momentum/tuning_search_space.json b/reference_algorithms/paper_baselines/momentum/tuning_search_space.json
index 8423bdab7..9ec39a6ef 100644
--- a/reference_algorithms/paper_baselines/momentum/tuning_search_space.json
+++ b/reference_algorithms/paper_baselines/momentum/tuning_search_space.json
@@ -1,21 +1,27 @@
{
"learning_rate": {
- "min": 2e-2, "max": 5.0, "scaling": "log"
+ "min": 2e-2,
+ "max": 5.0,
+ "scaling": "log"
},
"one_minus_beta1": {
- "min": 5e-3, "max": 0.3, "scaling": "log"
+ "min": 5e-3,
+ "max": 0.3,
+ "scaling": "log"
},
"warmup_factor": {
- "feasible_points": [0.05]
+ "feasible_points": [0.05]
},
"weight_decay": {
- "min": 1e-7, "max": 5e-5, "scaling": "log"
+ "min": 1e-7,
+ "max": 5e-5,
+ "scaling": "log"
},
"label_smoothing": {
- "feasible_points": [0.1, 0.2]
+ "feasible_points": [0.1, 0.2]
},
"dropout_rate": {
- "feasible_points": [0.0, 0.1]
+ "feasible_points": [0.0, 0.1]
},
"decay_steps_factor": {
"feasible_points": [0.9]
diff --git a/reference_algorithms/paper_baselines/momentum/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/momentum/tuning_search_space_no_beta1.json
index f874862d8..80f9c7968 100644
--- a/reference_algorithms/paper_baselines/momentum/tuning_search_space_no_beta1.json
+++ b/reference_algorithms/paper_baselines/momentum/tuning_search_space_no_beta1.json
@@ -1,21 +1,25 @@
{
"learning_rate": {
- "min": 2e-2, "max": 5.0, "scaling": "log"
+ "min": 2e-2,
+ "max": 5.0,
+ "scaling": "log"
},
"one_minus_beta1": {
- "feasible_points": [0.1]
+ "feasible_points": [0.1]
},
"warmup_factor": {
- "feasible_points": [0.05]
+ "feasible_points": [0.05]
},
"weight_decay": {
- "min": 1e-7, "max": 5e-5, "scaling": "log"
+ "min": 1e-7,
+ "max": 5e-5,
+ "scaling": "log"
},
"label_smoothing": {
- "feasible_points": [0.1, 0.2]
+ "feasible_points": [0.1, 0.2]
},
"dropout_rate": {
- "feasible_points": [0.0, 0.1]
+ "feasible_points": [0.0, 0.1]
},
"decay_steps_factor": {
"feasible_points": [0.9]
diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py
index c451a18ac..62161b3d5 100644
--- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py
+++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py
@@ -1,26 +1,24 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax."""
import functools
-
-# isort: off
-# We have to turn off isort here to resolve a conflict between isort and yapf.
-from typing import (Any,
- Callable,
- Dict,
- Iterator,
- List,
- NamedTuple,
- Optional,
- Tuple,
- Union)
-# isort: on
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ NamedTuple,
+ Optional,
+ Tuple,
+ Union,
+)
import chex
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
@@ -30,15 +28,14 @@
# Forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py
def nadamw(
- learning_rate: Union[float, optax.Schedule],
- b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- weight_decay: float = 0.0,
- weight_decay_mask: Optional[Union[Any, Callable[[optax.Params],
- Any]]] = None,
+ learning_rate: Union[float, optax.Schedule],
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ weight_decay: float = 0.0,
+ weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None,
) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
@@ -73,19 +70,22 @@ def nadamw(
An (init_fn, update_fn) tuple.
"""
return optax.chain(
- scale_by_nadam(b1, b2, eps, eps_root, debias),
- optax.add_decayed_weights(weight_decay, weight_decay_mask),
- scale_by_learning_rate(learning_rate))
+ scale_by_nadam(b1, b2, eps, eps_root, debias),
+ optax.add_decayed_weights(weight_decay, weight_decay_mask),
+ scale_by_learning_rate(learning_rate),
+ )
# All functions below are forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py
-def scale_by_nadam(b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- power: float = 0.5) -> optax.GradientTransformation:
+def scale_by_nadam(
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ power: float = 0.5,
+) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
References:
@@ -124,7 +124,8 @@ def update_fn(updates, state, params=None):
mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
nu_hat = nu if not debias else _bias_correction(nu, b2, count)
updates = jax.tree.map(
- lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
+ lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat
+ )
return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
return optax.GradientTransformation(init_fn, update_fn)
@@ -132,6 +133,7 @@ def update_fn(updates, state, params=None):
class ScaleByAdamState(NamedTuple):
"""State for the NAdam algorithm."""
+
count: chex.Array # shape=(), dtype=jnp.int32.
mu: optax.Updates
nu: optax.Updates
@@ -140,7 +142,8 @@ class ScaleByAdamState(NamedTuple):
def _update_moment(updates, moments, decay, order):
"""Compute the exponential moving average of the `order-th` moment."""
return jax.tree.map(
- lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
+ lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments
+ )
def _bias_correction(moment, decay, count):
@@ -156,11 +159,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True):
return optax.scale(m * learning_rate)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -170,101 +175,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+ init_value=hyperparameters.learning_rate, decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters)
opt_init_fn, opt_update_fn = nadamw(
- learning_rate=lr_schedule_fn,
- b1=1.0 - hyperparameters.one_minus_beta1,
- b2=hyperparameters.beta2,
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay)
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ learning_rate=lr_schedule_fn,
+ b1=1.0 - hyperparameters.one_minus_beta1,
+ b2=hyperparameters.beta2,
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ )
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -281,37 +300,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -351,14 +376,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py
index a2f9fb4c5..f6c2faa9d 100644
--- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py
@@ -3,13 +3,11 @@
import math
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import torch
-from torch import Tensor
import torch.distributed.nn as dist_nn
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from absl import logging
+from torch import Tensor
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
@@ -21,33 +19,30 @@
class NAdamW(torch.optim.Optimizer):
r"""Implements NAdamW algorithm.
- See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
- the NAdam algorithm (there is also a comment in the code which highlights
- the only difference of NAdamW and AdamW).
- For further details regarding the algorithm we refer to
- `Decoupled Weight Decay Regularization`_.
-
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float, optional): weight decay coefficient (default: 1e-2)
- .. _Decoupled Weight Decay Regularization:
- https://arxiv.org/abs/1711.05101
- .. _On the Convergence of Adam and Beyond:
- https://openreview.net/forum?id=ryQu7f-RZ
+ See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
+ the NAdam algorithm (there is also a comment in the code which highlights
+ the only difference of NAdamW and AdamW).
+ For further details regarding the algorithm we refer to
+ `Decoupled Weight Decay Regularization`_.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
"""
- def __init__(self,
- params,
- lr=1e-3,
- betas=(0.9, 0.999),
- eps=1e-8,
- weight_decay=1e-2):
+ def __init__(
+ self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
+ ):
if not 0.0 <= lr:
raise ValueError(f'Invalid learning rate: {lr}')
if not 0.0 <= eps:
@@ -59,7 +54,10 @@ def __init__(self,
if not 0.0 <= weight_decay:
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
defaults = {
- 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay
+ 'lr': lr,
+ 'betas': betas,
+ 'eps': eps,
+ 'weight_decay': weight_decay,
}
super().__init__(params, defaults)
@@ -67,7 +65,8 @@ def __setstate__(self, state):
super().__setstate__(state)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
- state_values[0]['step'])
+ state_values[0]['step']
+ )
if not step_is_tensor:
for s in state_values:
s['step'] = torch.tensor(float(s['step']))
@@ -76,9 +75,9 @@ def __setstate__(self, state):
def step(self, closure=None):
"""Performs a single optimization step.
- Args:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
"""
self._cuda_graph_capture_health_check()
@@ -107,51 +106,57 @@ def step(self, closure=None):
# State initialization
if len(state) == 0:
- state['step'] = torch.tensor(0.)
+ state['step'] = torch.tensor(0.0)
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
state_steps.append(state['step'])
nadamw(
- params_with_grad,
- grads,
- exp_avgs,
- exp_avg_sqs,
- state_steps,
- beta1=beta1,
- beta2=beta2,
- lr=group['lr'],
- weight_decay=group['weight_decay'],
- eps=group['eps'])
+ params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ state_steps,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ )
return loss
-def nadamw(params: List[Tensor],
- grads: List[Tensor],
- exp_avgs: List[Tensor],
- exp_avg_sqs: List[Tensor],
- state_steps: List[Tensor],
- beta1: float,
- beta2: float,
- lr: float,
- weight_decay: float,
- eps: float) -> None:
+def nadamw(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ state_steps: List[Tensor],
+ beta1: float,
+ beta2: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+) -> None:
r"""Functional API that performs NAdamW algorithm computation.
- See NAdamW class for details.
+ See NAdamW class for details.
"""
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
- 'API has changed, `state_steps` argument must contain a list of' +
- ' singleton tensors')
+ 'API has changed, `state_steps` argument must contain a list of'
+ + ' singleton tensors'
+ )
for i, param in enumerate(params):
grad = grads[i]
@@ -189,54 +194,59 @@ def nadamw(params: List[Tensor],
exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_state
del rng
optimizer_state = {
- 'optimizer':
- NAdamW(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(1.0 - hyperparameters.one_minus_beta1,
- hyperparameters.beta2),
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay),
+ 'optimizer': NAdamW(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2),
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ ),
}
def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup = LinearLR(
- optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
+ optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
- optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
+ optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
+ )
optimizer_state['scheduler'] = pytorch_cosine_warmup(
- workload.step_hint, hyperparameters, optimizer_state['optimizer'])
+ workload.step_hint, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -248,26 +258,30 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -280,7 +294,8 @@ def update_params(
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -289,31 +304,38 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -353,14 +375,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/nadamw/tuning_search_space.json b/reference_algorithms/paper_baselines/nadamw/tuning_search_space.json
index cba20c4c2..a3d322771 100644
--- a/reference_algorithms/paper_baselines/nadamw/tuning_search_space.json
+++ b/reference_algorithms/paper_baselines/nadamw/tuning_search_space.json
@@ -1,23 +1,29 @@
{
"learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
},
"one_minus_beta1": {
- "min": 4e-3, "max": 0.1, "scaling": "log"
+ "min": 4e-3,
+ "max": 0.1,
+ "scaling": "log"
},
"beta2": {
- "feasible_points": [0.999]
+ "feasible_points": [0.999]
},
"warmup_factor": {
- "feasible_points": [0.05]
+ "feasible_points": [0.05]
},
"weight_decay": {
- "min": 5e-3, "max": 1.0, "scaling": "log"
+ "min": 5e-3,
+ "max": 1.0,
+ "scaling": "log"
},
"label_smoothing": {
- "feasible_points": [0.1, 0.2]
+ "feasible_points": [0.1, 0.2]
},
"dropout_rate": {
- "feasible_points": [0.0, 0.1]
+ "feasible_points": [0.0, 0.1]
}
}
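
Each entry in the tuning search spaces above is either a continuous range (min/max with log scaling) or a discrete set of feasible_points. As a rough sketch of how such an entry could be sampled, a hypothetical helper is shown below; the benchmark's own tuning harness may do this differently.

import math
import random

def sample_hparam(entry, rng=random):
  # `entry` is one value of the JSON above,
  # e.g. {"min": 1e-4, "max": 1e-2, "scaling": "log"}.
  if 'feasible_points' in entry:
    return rng.choice(entry['feasible_points'])
  lo, hi = entry['min'], entry['max']
  if entry.get('scaling') == 'log':
    # Sample uniformly in log space for log-scaled ranges.
    return math.exp(rng.uniform(math.log(lo), math.log(hi)))
  return rng.uniform(lo, hi)
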
diff --git a/reference_algorithms/paper_baselines/nadamw/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/nadamw/tuning_search_space_no_beta1.json
index 58973eb27..5a7c27be7 100644
--- a/reference_algorithms/paper_baselines/nadamw/tuning_search_space_no_beta1.json
+++ b/reference_algorithms/paper_baselines/nadamw/tuning_search_space_no_beta1.json
@@ -1,23 +1,27 @@
{
"learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
},
"one_minus_beta1": {
- "feasible_points": [0.1]
+ "feasible_points": [0.1]
},
"beta2": {
- "feasible_points": [0.999]
+ "feasible_points": [0.999]
},
"warmup_factor": {
- "feasible_points": [0.05]
+ "feasible_points": [0.05]
},
"weight_decay": {
- "min": 5e-3, "max": 1.0, "scaling": "log"
+ "min": 5e-3,
+ "max": 1.0,
+ "scaling": "log"
},
"label_smoothing": {
- "feasible_points": [0.1, 0.2]
+ "feasible_points": [0.1, 0.2]
},
"dropout_rate": {
- "feasible_points": [0.0, 0.1]
+ "feasible_points": [0.0, 0.1]
}
}
diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py
index 0e53aae42..18e58b3c0 100644
--- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py
+++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py
@@ -3,22 +3,24 @@
import functools
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
_GRAD_CLIP_EPS = 1e-6
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -28,34 +30,39 @@ def init_optimizer_state(workload: spec.Workload,
lr_schedule_fn = create_lr_schedule_fn(workload.step_hint, hyperparameters)
# Create optimizer.
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
opt_init_fn, opt_update_fn = sgd(
- learning_rate=lr_schedule_fn,
- weight_decay=hyperparameters.weight_decay,
- momentum=1.0 - hyperparameters.one_minus_beta1,
- nesterov=True)
+ learning_rate=lr_schedule_fn,
+ weight_decay=hyperparameters.weight_decay,
+ momentum=1.0 - hyperparameters.one_minus_beta1,
+ nesterov=True,
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
def create_lr_schedule_fn(
- step_hint: int,
- hyperparameters: spec.Hyperparameters) -> Callable[[int], float]:
+ step_hint: int, hyperparameters: spec.Hyperparameters
+) -> Callable[[int], float]:
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
decay_steps = step_hint - warmup_steps
polynomial_schedule_fn = optax.polynomial_schedule(
- init_value=hyperparameters.learning_rate,
- end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
- power=1,
- transition_steps=int(decay_steps * hyperparameters.decay_steps_factor))
+ init_value=hyperparameters.learning_rate,
+ end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
+ power=1,
+ transition_steps=int(decay_steps * hyperparameters.decay_steps_factor),
+ )
lr_schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps]
+ )
return lr_schedule_fn
@@ -82,81 +89,92 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False):
{SGD, Momentum, Nesterov} update.
"""
return optax.chain(
- optax.add_decayed_weights(weight_decay),
- optax.sgd(
- learning_rate=learning_rate, momentum=momentum, nesterov=nesterov))
+ optax.add_decayed_weights(weight_decay),
+ optax.sgd(
+ learning_rate=learning_rate, momentum=momentum, nesterov=nesterov
+ ),
+ )
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -173,37 +191,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -243,14 +267,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py
index b4432fbff..9d3bfa6e7 100644
--- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py
@@ -2,10 +2,10 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import optax
import torch
import torch.distributed.nn as dist_nn
+from absl import logging
from torch.optim.lr_scheduler import LambdaLR
from algoperf import spec
@@ -14,24 +14,26 @@
USE_PYTORCH_DDP = pytorch_setup()[0]
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule."""
del model_state
del rng
# Create optimizer.
optimizer_state = {
- 'optimizer':
- torch.optim.SGD(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- momentum=1.0 - hyperparameters.one_minus_beta1,
- weight_decay=hyperparameters.weight_decay,
- nesterov=True),
+ 'optimizer': torch.optim.SGD(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ momentum=1.0 - hyperparameters.one_minus_beta1,
+ weight_decay=hyperparameters.weight_decay,
+ nesterov=True,
+ ),
}
# Create learning rate schedule.
@@ -43,43 +45,48 @@ def _lr_lambda(step: int) -> float:
return lr_schedule_fn(step).item() / hyperparameters.learning_rate
optimizer_state['scheduler'] = LambdaLR(
- optimizer_state['optimizer'], lr_lambda=_lr_lambda)
+ optimizer_state['optimizer'], lr_lambda=_lr_lambda
+ )
return optimizer_state
def create_lr_schedule_fn(
- step_hint: int,
- hyperparameters: spec.Hyperparameters) -> Callable[[int], float]:
+ step_hint: int, hyperparameters: spec.Hyperparameters
+) -> Callable[[int], float]:
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
decay_steps = step_hint - warmup_steps
polynomial_schedule_fn = optax.polynomial_schedule(
- init_value=hyperparameters.learning_rate,
- end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
- power=1,
- transition_steps=int(decay_steps * hyperparameters.decay_steps_factor))
+ init_value=hyperparameters.learning_rate,
+ end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
+ power=1,
+ transition_steps=int(decay_steps * hyperparameters.decay_steps_factor),
+ )
lr_schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps]
+ )
return lr_schedule_fn
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -91,26 +98,30 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -123,7 +134,8 @@ def update_params(
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -132,31 +144,38 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -196,14 +215,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/nesterov/tuning_search_space.json b/reference_algorithms/paper_baselines/nesterov/tuning_search_space.json
index 8423bdab7..9ec39a6ef 100644
--- a/reference_algorithms/paper_baselines/nesterov/tuning_search_space.json
+++ b/reference_algorithms/paper_baselines/nesterov/tuning_search_space.json
@@ -1,21 +1,27 @@
{
"learning_rate": {
- "min": 2e-2, "max": 5.0, "scaling": "log"
+ "min": 2e-2,
+ "max": 5.0,
+ "scaling": "log"
},
"one_minus_beta1": {
- "min": 5e-3, "max": 0.3, "scaling": "log"
+ "min": 5e-3,
+ "max": 0.3,
+ "scaling": "log"
},
"warmup_factor": {
- "feasible_points": [0.05]
+ "feasible_points": [0.05]
},
"weight_decay": {
- "min": 1e-7, "max": 5e-5, "scaling": "log"
+ "min": 1e-7,
+ "max": 5e-5,
+ "scaling": "log"
},
"label_smoothing": {
- "feasible_points": [0.1, 0.2]
+ "feasible_points": [0.1, 0.2]
},
"dropout_rate": {
- "feasible_points": [0.0, 0.1]
+ "feasible_points": [0.0, 0.1]
},
"decay_steps_factor": {
"feasible_points": [0.9]
diff --git a/reference_algorithms/paper_baselines/nesterov/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/nesterov/tuning_search_space_no_beta1.json
index f874862d8..80f9c7968 100644
--- a/reference_algorithms/paper_baselines/nesterov/tuning_search_space_no_beta1.json
+++ b/reference_algorithms/paper_baselines/nesterov/tuning_search_space_no_beta1.json
@@ -1,21 +1,25 @@
{
"learning_rate": {
- "min": 2e-2, "max": 5.0, "scaling": "log"
+ "min": 2e-2,
+ "max": 5.0,
+ "scaling": "log"
},
"one_minus_beta1": {
- "feasible_points": [0.1]
+ "feasible_points": [0.1]
},
"warmup_factor": {
- "feasible_points": [0.05]
+ "feasible_points": [0.05]
},
"weight_decay": {
- "min": 1e-7, "max": 5e-5, "scaling": "log"
+ "min": 1e-7,
+ "max": 5e-5,
+ "scaling": "log"
},
"label_smoothing": {
- "feasible_points": [0.1, 0.2]
+ "feasible_points": [0.1, 0.2]
},
"dropout_rate": {
- "feasible_points": [0.0, 0.1]
+ "feasible_points": [0.0, 0.1]
},
"decay_steps_factor": {
"feasible_points": [0.9]
diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py
index b76589705..9ab193b56 100644
--- a/reference_algorithms/paper_baselines/sam/jax/submission.py
+++ b/reference_algorithms/paper_baselines/sam/jax/submission.py
@@ -3,11 +3,11 @@
import functools
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
@@ -23,7 +23,8 @@ def dual_vector(y: jnp.ndarray) -> jnp.ndarray:
y: A pytree of numpy ndarray, vector y in the equation above.
"""
gradient_norm = jnp.sqrt(
- sum(jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y)))
+ sum(jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y))
+ )
normalized_gradient = jax.tree.map(lambda x: x / gradient_norm, y)
return normalized_gradient
@@ -31,11 +32,11 @@ def dual_vector(y: jnp.ndarray) -> jnp.ndarray:
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/
# sharpness_aware_minimization.py
def sharpness_aware_minimization(
- rho: float,
- grad_clip: Optional[float],
- batch_axis_name: str,
- base_opt_init_fn,
- base_opt_update_fn,
+ rho: float,
+ grad_clip: Optional[float],
+ batch_axis_name: str,
+ base_opt_init_fn,
+ base_opt_update_fn,
) -> optax.GradientTransformation:
"""Implementation of Sharpness Aware Minimization (SAM).
Paper: https://arxiv.org/abs/2010.01412
@@ -68,22 +69,28 @@ def update_fn(updates, state, grad_fn_params_tuple):
# with the same 1e-6 epsilon that is used when clipping the gradients.
updates = dual_vector(updates)
noised_params = jax.tree_util.tree_map(
- lambda p, u: p + rho * u, params, updates)
+ lambda p, u: p + rho * u, params, updates
+ )
(_, (n_valid_examples, _)), updates = grad_fn(noised_params)
# Get correct global mean grad.
- (n_valid_examples, updates) = lax.psum((n_valid_examples, updates),
- axis_name=batch_axis_name)
+ (n_valid_examples, updates) = lax.psum(
+ (n_valid_examples, updates), axis_name=batch_axis_name
+ )
updates = jax.tree.map(lambda x: x / n_valid_examples, updates)
if grad_clip:
updates_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))
+ )
scaled_updates = jax.tree.map(
- lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates)
- updates = jax.lax.cond(updates_norm > grad_clip,
- lambda _: scaled_updates,
- lambda _: updates,
- None)
+ lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates
+ )
+ updates = jax.lax.cond(
+ updates_norm > grad_clip,
+ lambda _: scaled_updates,
+ lambda _: updates,
+ None,
+ )
updates, state = base_opt_update_fn(updates, state, params)
return updates, state
@@ -91,11 +98,13 @@ def update_fn(updates, state, grad_fn_params_tuple):
return optax.GradientTransformation(init_fn, update_fn)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a SAM optimizer (with AdamW base) and a learning rate schedule."""
del model_params
del model_state
@@ -105,111 +114,127 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+ init_value=hyperparameters.learning_rate, decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create base optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters)
opt_init_fn, opt_update_fn = optax.adamw(
- learning_rate=lr_schedule_fn,
- b1=1.0 - hyperparameters.one_minus_beta1,
- b2=hyperparameters.beta2,
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay)
+ learning_rate=lr_schedule_fn,
+ b1=1.0 - hyperparameters.one_minus_beta1,
+ b2=hyperparameters.beta2,
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ )
# Create SAM update fn.
grad_clip = (
- hyperparameters.grad_clip
- if hasattr(hyperparameters, 'grad_clip') else None)
+ hyperparameters.grad_clip if hasattr(hyperparameters, 'grad_clip') else None
+ )
opt_init_fn, opt_update_fn = sharpness_aware_minimization(
- rho=hyperparameters.rho,
- grad_clip=grad_clip,
- batch_axis_name='batch',
- base_opt_init_fn=opt_init_fn,
- base_opt_update_fn=opt_update_fn)
+ rho=hyperparameters.rho,
+ grad_clip=grad_clip,
+ batch_axis_name='batch',
+ base_opt_init_fn=opt_init_fn,
+ base_opt_update_fn=opt_update_fn,
+ )
# Initialize optimizer state.
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params, update_batch_norm=True):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=update_batch_norm)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=update_batch_norm,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
second_grad_fn = jax.value_and_grad(
- functools.partial(_loss_fn, update_batch_norm=False), has_aux=True)
+ functools.partial(_loss_fn, update_batch_norm=False), has_aux=True
+ )
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
updates, new_optimizer_state = opt_update_fn(
- grad, optimizer_state, (second_grad_fn, current_param_container))
+ grad, optimizer_state, (second_grad_fn, current_param_container)
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -226,37 +251,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -293,14 +324,15 @@ def get_batch_size(workload_name):
def data_selection(
- workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]:
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py
index 92603f036..652ebed1d 100644
--- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py
@@ -2,12 +2,10 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import torch
import torch.distributed.nn as dist_nn
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from absl import logging
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
@@ -17,13 +15,14 @@
# Modified from https://github.com/davda54/sam.
class SAM(torch.optim.Optimizer):
-
- def __init__(self,
- params: spec.ParameterContainer,
- base_optimizer: torch.optim.Optimizer,
- rho: float = 0.05,
- adaptive: bool = False,
- **kwargs):
+ def __init__(
+ self,
+ params: spec.ParameterContainer,
+ base_optimizer: torch.optim.Optimizer,
+ rho: float = 0.05,
+ adaptive: bool = False,
+ **kwargs,
+ ):
if rho < 0.0:
raise ValueError(f'Invalid rho, should be non-negative: {rho}')
@@ -79,12 +78,18 @@ def _grad_norm(self):
# In case of model parallelism, put everything on the same device.
shared_device = self.param_groups[0]['params'][0].device
norm = torch.norm(
- torch.stack([((torch.abs(p) if group['adaptive'] else 1.0) *
- p.grad).norm(p=2).to(shared_device)
- for group in self.param_groups
- for p in group['params']
- if p.grad is not None]),
- p=2)
+ torch.stack(
+ [
+ ((torch.abs(p) if group['adaptive'] else 1.0) * p.grad)
+ .norm(p=2)
+ .to(shared_device)
+ for group in self.param_groups
+ for p in group['params']
+ if p.grad is not None
+ ]
+ ),
+ p=2,
+ )
return norm
def load_state_dict(self, state_dict: Dict):
@@ -92,11 +97,13 @@ def load_state_dict(self, state_dict: Dict):
self.base_optimizer.param_groups = self.param_groups
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates an AdamW optimizer and a learning rate schedule."""
del model_state
del rng
@@ -104,46 +111,50 @@ def init_optimizer_state(workload: spec.Workload,
# Create SAM optimizer with AdamW base.
base_optimizer = torch.optim.AdamW
optimizer_state = {
- 'optimizer':
- SAM(model_params.parameters(),
- base_optimizer=base_optimizer,
- rho=hyperparameters.rho,
- lr=hyperparameters.learning_rate,
- betas=(1.0 - hyperparameters.one_minus_beta1,
- hyperparameters.beta2),
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay),
+ 'optimizer': SAM(
+ model_params.parameters(),
+ base_optimizer=base_optimizer,
+ rho=hyperparameters.rho,
+ lr=hyperparameters.learning_rate,
+ betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2),
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ ),
}
def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup = LinearLR(
- optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
+ optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
- optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
+ optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
+ )
# Create learning rate schedule.
optimizer_state['scheduler'] = pytorch_cosine_warmup(
- workload.step_hint, hyperparameters, optimizer_state['optimizer'])
+ workload.step_hint, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -156,20 +167,24 @@ def update_params(
def _loss_fn(params, update_batch_norm=True):
"""Loss function used for training."""
logits_batch, new_model_state = workload.model_fn(
- params=params,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=update_batch_norm)
+ params=params,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=update_batch_norm,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -187,7 +202,8 @@ def _loss_fn(params, update_batch_norm=True):
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
optimizer_state['optimizer'].first_step(zero_grad=True)
@@ -198,7 +214,8 @@ def _loss_fn(params, update_batch_norm=True):
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].second_step(zero_grad=True)
optimizer_state['scheduler'].step()
@@ -206,29 +223,34 @@ def _loss_fn(params, update_batch_norm=True):
if global_step <= 100 or global_step % 500 == 0:
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': logging_loss.item(),
- 'grad_norm': grad_norm.item(),
- },
- global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- logging_loss.item(),
- grad_norm.item())
+ {
+ 'loss': logging_loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ logging_loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -265,14 +287,15 @@ def get_batch_size(workload_name):
def data_selection(
- workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]:
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/sam/tuning_search_space.json b/reference_algorithms/paper_baselines/sam/tuning_search_space.json
index 66dae232b..f32058937 100644
--- a/reference_algorithms/paper_baselines/sam/tuning_search_space.json
+++ b/reference_algorithms/paper_baselines/sam/tuning_search_space.json
@@ -1,26 +1,32 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "min": 5e-2, "max": 0.43, "scaling": "log"
- },
- "beta2": {
- "feasible_points": [0.999]
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 1e-2, "max": 0.2, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- },
- "rho": {
- "feasible_points": [0.01, 0.02, 0.05]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "min": 5e-2,
+ "max": 0.43,
+ "scaling": "log"
+ },
+ "beta2": {
+ "feasible_points": [0.999]
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 1e-2,
+ "max": 0.2,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ },
+ "rho": {
+ "feasible_points": [0.01, 0.02, 0.05]
+ }
}
diff --git a/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json
index 89c480e7a..ee4e0c3e4 100644
--- a/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json
+++ b/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json
@@ -1,26 +1,30 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "feasible_points": [0.1]
- },
- "beta2": {
- "feasible_points": [0.999]
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 1e-2, "max": 0.2, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- },
- "rho": {
- "feasible_points": [0.01, 0.02, 0.05]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "feasible_points": [0.1]
+ },
+ "beta2": {
+ "feasible_points": [0.999]
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 1e-2,
+ "max": 0.2,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ },
+ "rho": {
+ "feasible_points": [0.01, 0.02, 0.05]
+ }
}
diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py
index a5c2732ac..c719361d3 100644
--- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py
+++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py
@@ -36,17 +36,17 @@
import functools
import itertools
import logging
-from typing import Any, cast, List, NamedTuple, Optional, TypeVar, Union
+from typing import Any, List, NamedTuple, Optional, TypeVar, Union, cast
import chex
-from flax import struct
import jax
-from jax import lax
-from jax.experimental import pjit
-from jax.experimental.sparse import linalg
import jax.numpy as jnp
import numpy as np
import optax
+from flax import struct
+from jax import lax
+from jax.experimental import pjit
+from jax.experimental.sparse import linalg
# Dtype for inverse-pth root routine
# Switch to f64 if you have hardware that supports it. Enable the jax flag
@@ -61,13 +61,16 @@
@struct.dataclass
class QuantizedValue:
"""State associated with quantized value."""
+
quantized: chex.Array
diagonal: chex.Array # Diagonal (if extract_diagonal is set)
bucket_size: chex.Array
quantized_dtype: jnp.dtype = struct.field(
- pytree_node=False) # Dtype for the quantized value.
+ pytree_node=False
+ ) # Dtype for the quantized value.
extract_diagonal: bool = struct.field(
- pytree_node=False) # In case its centered.
+ pytree_node=False
+ ) # In case its centered.
shape: Any = struct.field(pytree_node=False) # Shape of the tensor.
@classmethod
@@ -75,13 +78,16 @@ def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
if isinstance(fvalue, list) and not fvalue:
return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
- fvalue, quantized_dtype, extract_diagonal)
- return QuantizedValue(quantized,
- diagonal_fvalue,
- bucket_size,
- quantized_dtype,
- extract_diagonal,
- list(quantized.shape))
+ fvalue, quantized_dtype, extract_diagonal
+ )
+ return QuantizedValue(
+ quantized,
+ diagonal_fvalue,
+ bucket_size,
+ quantized_dtype,
+ extract_diagonal,
+ list(quantized.shape),
+ )
# Quantization is from Lingvo JAX optimizers.
# We extend it for int16 quantization of PSD matrices.
@@ -106,7 +112,8 @@ def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
if extract_diagonal and fvalue.ndim != 2:
raise ValueError(
- f'Input array {fvalue} must be 2D to work with extract_diagonal.')
+ f'Input array {fvalue} must be 2D to work with extract_diagonal.'
+ )
diagonal_fvalue = []
if extract_diagonal:
@@ -119,16 +126,17 @@ def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
# We first decide the scale.
if fvalue.ndim < 1:
raise ValueError(
- f'Input array {fvalue} must have a strictly positive number of '
- 'dimensions.')
+ f'Input array {fvalue} must have a strictly positive number of '
+ 'dimensions.'
+ )
max_abs = jnp.max(jnp.abs(fvalue), axis=0)
bucket_size = max_abs / num_buckets
bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
# To avoid divide by 0.0
- bs_nonzero = jnp.where(bs_expanded > 0.0,
- bs_expanded,
- jnp.ones_like(bs_expanded))
+ bs_nonzero = jnp.where(
+ bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded)
+ )
ratio = fvalue / bs_nonzero
# We use rounding to remove bias.
quantized = jnp.round(ratio)
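
For intuition, a minimal NumPy sketch of the bucketed quantization scheme this hunk reformats (illustrative only, not part of the patch; `num_buckets=127` for int8 is an assumption here, the library derives it from the target dtype):

```python
import numpy as np

def quantize_columns(fvalue, num_buckets=127):
  """Per-column bucketed quantization: round(value / bucket_size)."""
  max_abs = np.max(np.abs(fvalue), axis=0)
  bucket_size = max_abs / num_buckets
  # Avoid divide-by-zero for all-zero columns, as in the code above.
  safe_bucket = np.where(bucket_size > 0.0, bucket_size, 1.0)
  quantized = np.round(fvalue / safe_bucket).astype(np.int8)
  return quantized, bucket_size

def dequantize_columns(quantized, bucket_size):
  """Approximate inverse: quantized * bucket_size."""
  return quantized.astype(np.float32) * bucket_size

x = np.random.RandomState(0).randn(4, 3).astype(np.float32)
q, bs = quantize_columns(x)
# Rounding error is at most half a bucket per entry.
assert np.max(np.abs(dequantize_columns(q, bs) - x)) <= np.max(bs) / 2 + 1e-6
```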
@@ -155,10 +163,11 @@ def to_float(self):
def _default_zero_field():
return struct.field(
- default_factory=functools.partial(jnp.array, 0, jnp.float32))
+ default_factory=functools.partial(jnp.array, 0, jnp.float32)
+ )
-T = TypeVar("T")
+T = TypeVar('T')
def _maybe_ix(ls, ix):
@@ -180,17 +189,19 @@ def wrap_f(x, *args, **kwargs):
InversePthRootDiagnosticsSubtype = TypeVar(
- "InversePthRootDiagnosticsSubtype", bound="InversePthRootDiagnostics")
+ 'InversePthRootDiagnosticsSubtype', bound='InversePthRootDiagnostics'
+)
@struct.dataclass
class InversePthRootDiagnostics:
"""Diagnostics for inverse p-th root iterative procedure.
- Given an inverse pth root B = A^(-1/p), contains the average and
- maximum diagonal and off diagonal absolute entrywise errors between
- (B^p A) and I.
- """
+ Given an inverse pth root B = A^(-1/p), contains the average and
+ maximum diagonal and off diagonal absolute entrywise errors between
+ (B^p A) and I.
+ """
+
max_diag_error: chex.Array = _default_zero_field()
avg_diag_error: chex.Array = _default_zero_field()
max_off_diag_error: chex.Array = _default_zero_field()
@@ -201,35 +212,41 @@ class InversePthRootDiagnostics:
def create(cls, pth_inverse_root, matrix, p):
"""Generates a diagnostics struct from (-1/p) root result."""
mat_m = jnp.matmul(
- mat_power(pth_inverse_root, p),
- matrix,
- precision=jax.lax.Precision.HIGHEST)
+ mat_power(pth_inverse_root, p),
+ matrix,
+ precision=jax.lax.Precision.HIGHEST,
+ )
num_off_diag_entries = mat_m.size - jnp.diag(mat_m).size
diag_error = jnp.abs(jnp.diag(mat_m) - 1).astype(jnp.float32)
off_diag_error = jnp.abs(mat_m - jnp.diag(jnp.diag(mat_m))).astype(
- jnp.float32)
+ jnp.float32
+ )
return cls(
- max_diag_error=jnp.max(diag_error).astype(jnp.float32),
- avg_diag_error=jnp.mean(diag_error).astype(jnp.float32),
- max_off_diag_error=jnp.max(off_diag_error).astype(jnp.float32),
- avg_off_diag_error=(jnp.sum(off_diag_error) /
- num_off_diag_entries).astype(jnp.float32),
- p=jnp.array(p, jnp.float32))
+ max_diag_error=jnp.max(diag_error).astype(jnp.float32),
+ avg_diag_error=jnp.mean(diag_error).astype(jnp.float32),
+ max_off_diag_error=jnp.max(off_diag_error).astype(jnp.float32),
+ avg_off_diag_error=(
+ jnp.sum(off_diag_error) / num_off_diag_entries
+ ).astype(jnp.float32),
+ p=jnp.array(p, jnp.float32),
+ )
LOBPCGDiagnosticsSubtype = TypeVar(
- "LOBPCGDiagnosticsSubtype", bound="LOBPCGDiagnostics")
+ 'LOBPCGDiagnosticsSubtype', bound='LOBPCGDiagnostics'
+)
@struct.dataclass
class LOBPCGDiagnostics:
"""Diagnostics for iterative LOBPCG eigenvalue routine.
- Contains consistency error for LOBPCG eigenvalue routine, which
- refers to |A v - lambda v| / (lambda + |A v|) for a proposed eigenpair
- (v, lambda). This metics dataclass retains consistency error
- and other useful LOBPCG values.
- """
+ Contains consistency error for LOBPCG eigenvalue routine, which
+ refers to |A v - lambda v| / (lambda + |A v|) for a proposed eigenpair
+ (v, lambda). This metrics dataclass retains consistency error
+ and other useful LOBPCG values.
+ """
+
lobpcg_iters: chex.Array = _default_zero_field()
max_consistency_error: chex.Array = _default_zero_field()
avg_consistency_error: chex.Array = _default_zero_field()
@@ -248,7 +265,8 @@ def create(cls, matrix, eigvals, eigvecs, lobpcg_iters):
mat_eigvecs = matrix.dot(eigvecs, precision=precision)
consistency_error_unnormalized = jnp.linalg.norm(
- mat_eigvecs - eigvals * eigvecs, ord=2, axis=0)
+ mat_eigvecs - eigvals * eigvecs, ord=2, axis=0
+ )
normalization = jnp.linalg.norm(mat_eigvecs, ord=2, axis=0) + eigvals
consistency_error = consistency_error_unnormalized / normalization
@@ -256,20 +274,22 @@ def create(cls, matrix, eigvals, eigvecs, lobpcg_iters):
orthogonality_error -= jnp.diag(jnp.diag(orthogonality_error))
return cls(
- lobpcg_iters=jnp.array(lobpcg_iters, jnp.float32),
- max_consistency_error=jnp.max(consistency_error).astype(jnp.float32),
- avg_consistency_error=jnp.mean(consistency_error).astype(jnp.float32),
- avg_orthogonality_error=(jnp.sum(orthogonality_error) /
- num_off_diag).astype(jnp.float32),
- max_eigenvalue=jnp.max(eigvals).astype(jnp.float32),
- min_eigenvalue=jnp.min(eigvals).astype(jnp.float32),
- num_topk_eigenvectors=jnp.array(num_topk, jnp.float32),
+ lobpcg_iters=jnp.array(lobpcg_iters, jnp.float32),
+ max_consistency_error=jnp.max(consistency_error).astype(jnp.float32),
+ avg_consistency_error=jnp.mean(consistency_error).astype(jnp.float32),
+ avg_orthogonality_error=(
+ jnp.sum(orthogonality_error) / num_off_diag
+ ).astype(jnp.float32),
+ max_eigenvalue=jnp.max(eigvals).astype(jnp.float32),
+ min_eigenvalue=jnp.min(eigvals).astype(jnp.float32),
+ num_topk_eigenvectors=jnp.array(num_topk, jnp.float32),
)
@struct.dataclass
class TrainingMetrics:
"""Diagnostic metrics from training."""
+
# Error for inverse-pth roots.
inverse_pth_root_errors: chex.Array = _default_zero_field()
# Iteration count for inverse-pth roots.
@@ -283,20 +303,24 @@ class TrainingMetrics:
total_retries: chex.Array = _default_zero_field()
lobpcg_diagnostics: LOBPCGDiagnostics = struct.field(
- default_factory=LOBPCGDiagnostics)
+ default_factory=LOBPCGDiagnostics
+ )
# Rich matrix entrywise error diagnostics, if enabled.
inverse_pth_root_diagnostics: InversePthRootDiagnostics = struct.field(
- default_factory=InversePthRootDiagnostics)
+ default_factory=InversePthRootDiagnostics
+ )
# Diagnostics applied to the conditioned p-th root problem, after top
# eigenvectors are removed, if LOBPCG is being applied.
conditioned_inverse_pth_root_diagnostics: InversePthRootDiagnostics = (
- struct.field(default_factory=InversePthRootDiagnostics))
+ struct.field(default_factory=InversePthRootDiagnostics)
+ )
# TODO(rohananil): Add more important metrics to track during training.
# Per parameter optimizer state used in data-parallel training.
class ParameterStats(NamedTuple):
"""State associated to each parameter of the model being trained."""
+
diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
statistics: Optional[List[Any]] # Statistics (QuantizedValue, chex.Array)
preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array)
@@ -321,12 +345,14 @@ class GlobalShardedParameterStats:
@struct.dataclass
class LocalShardedParameterStats:
"""State associated to each parameter of the model being trained."""
+
diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
momentum: QuantizedValue # Momentum for the shampoo preconditioner
training_metrics: Union[TrainingMetrics, optax.MaskedNode]
index_start: Union[np.int32, int] = struct.field(
- pytree_node=False) # Index into global statistics array
+ pytree_node=False
+ ) # Index into global statistics array
sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics.
@@ -336,39 +362,44 @@ def default_training_metrics():
def init_training_metrics(
- num_statistics,
- generate_training_metrics,
+ num_statistics,
+ generate_training_metrics,
):
"""Initialize TrainingMetrics, masked if disabled."""
if not generate_training_metrics:
return optax.MaskedNode()
return jax.tree.map(
- functools.partial(jnp.repeat, repeats=num_statistics),
- default_training_metrics())
+ functools.partial(jnp.repeat, repeats=num_statistics),
+ default_training_metrics(),
+ )
def init_training_metrics_shapes(
- num_statistics,
- generate_training_metrics,
+ num_statistics,
+ generate_training_metrics,
):
"""Initialize training metrics shape/dtype."""
seed = init_training_metrics(
- num_statistics,
- generate_training_metrics,
+ num_statistics,
+ generate_training_metrics,
)
return jax.tree.map(lambda arr: [list(arr.shape), arr.dtype], seed)
-def init_training_metrics_pspec(generate_training_metrics,):
+def init_training_metrics_pspec(
+ generate_training_metrics,
+):
"""Initialize training metrics partition specification."""
if not generate_training_metrics:
return optax.MaskedNode()
- return jax.tree.map(lambda _: jax.sharding.PartitionSpec(),
- default_training_metrics())
+ return jax.tree.map(
+ lambda _: jax.sharding.PartitionSpec(), default_training_metrics()
+ )
class ShardedShampooStats(NamedTuple):
"""Shampoo state in sharded mode."""
+
global_stats: Any
local_stats: Any
@@ -406,35 +437,35 @@ class PreconditionerType(enum.IntEnum):
def power_iteration(
- matrix,
- num_iters=100,
- error_tolerance=1e-6,
- precision=lax.Precision.HIGHEST,
- padding_start=None,
+ matrix,
+ num_iters=100,
+ error_tolerance=1e-6,
+ precision=lax.Precision.HIGHEST,
+ padding_start=None,
):
r"""Power iteration algorithm.
- The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
- a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue
- of `A`, and a vector v, which is the corresponding eigenvector of `A`.
-
- References:
- [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)
-
- Args:
- matrix: the symmetric PSD matrix.
- num_iters: Number of iterations.
- error_tolerance: Iterative exit condition.
- precision: precision XLA related flag, the available options are: a)
- lax.Precision.DEFAULT (better step time, but not precise) b)
- lax.Precision.HIGH (increased precision, slower) c)
- lax.Precision.HIGHEST (best possible precision, slowest)
- padding_start: if set, assumes rows and columns after padding_start are
- zero.
-
- Returns:
- eigen vector, eigen value
- """
+ The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
+ a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue
+ of `A`, and a vector v, which is the corresponding eigenvector of `A`.
+
+ References:
+ [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)
+
+ Args:
+ matrix: the symmetric PSD matrix.
+ num_iters: Number of iterations.
+ error_tolerance: Iterative exit condition.
+ precision: precision XLA related flag, the available options are: a)
+ lax.Precision.DEFAULT (better step time, but not precise) b)
+ lax.Precision.HIGH (increased precision, slower) c)
+ lax.Precision.HIGHEST (best possible precision, slowest)
+ padding_start: if set, assumes rows and columns after padding_start are
+ zero.
+
+ Returns:
+ eigenvector, eigenvalue
+ """
matrix_size = matrix.shape[-1]
def _iter_condition(state):
@@ -446,32 +477,38 @@ def _iter_body(state):
i, new_v, s, s_v, unused_run_step = state
new_v = new_v / jnp.linalg.norm(new_v)
- s_v = jnp.einsum("ij,j->i", matrix, new_v, precision=precision)
- s_new = jnp.einsum("i,i->", new_v, s_v, precision=precision)
- return (i + 1,
- s_v,
- s_new,
- s_v,
- jnp.greater(jnp.abs(s_new - s), error_tolerance))
+ s_v = jnp.einsum('ij,j->i', matrix, new_v, precision=precision)
+ s_new = jnp.einsum('i,i->', new_v, s_v, precision=precision)
+ return (
+ i + 1,
+ s_v,
+ s_new,
+ s_v,
+ jnp.greater(jnp.abs(s_new - s), error_tolerance),
+ )
# Figure out how to use step as seed for random.
- v_0 = np.random.RandomState(1729).uniform(-1.0, 1.0,
- matrix_size).astype(matrix.dtype)
+ v_0 = (
+ np.random.RandomState(1729)
+ .uniform(-1.0, 1.0, matrix_size)
+ .astype(matrix.dtype)
+ )
v_0 = jnp.array(v_0)
if padding_start is not None:
- v_0 *= (jnp.arange(len(v_0), dtype=jnp.int32) < padding_start)
+ v_0 *= jnp.arange(len(v_0), dtype=jnp.int32) < padding_start
init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
- _, v_out, s_out, _, _ = lax.while_loop(_iter_condition, _iter_body,
- init_state)
+ _, v_out, s_out, _, _ = lax.while_loop(
+ _iter_condition, _iter_body, init_state
+ )
v_out = v_out / jnp.linalg.norm(v_out)
return v_out, s_out
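
For reference, a plain-NumPy version of the power-iteration loop described in the docstring above (illustrative only; the JAX code above additionally handles padding and runs under `lax.while_loop`):

```python
import numpy as np

def power_iteration_np(matrix, num_iters=100, error_tolerance=1e-6):
  """Dominant eigenpair (eigenvector, eigenvalue) of a symmetric PSD matrix."""
  v = np.random.RandomState(1729).uniform(-1.0, 1.0, matrix.shape[-1])
  s = 0.0
  for _ in range(num_iters):
    v = v / np.linalg.norm(v)
    s_v = matrix @ v   # A v
    s_new = v @ s_v    # Rayleigh quotient for the unit vector v
    converged = abs(s_new - s) <= error_tolerance
    v, s = s_v, s_new
    if converged:
      break
  return v / np.linalg.norm(v), s
```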
def mat_power(
- mat_m,
- p,
- precision=lax.Precision.HIGHEST,
+ mat_m,
+ p,
+ precision=lax.Precision.HIGHEST,
):
"""A simple matrix power method. M^p where p can be TracedValue."""
power = jnp.eye(mat_m.shape[0], dtype=_MAT_INV_PTH_ROOT_DTYPE)
@@ -483,9 +520,11 @@ def _iter_condition(state):
def _iter_body(state):
i, power, mat = state
- power = jax.lax.cond(i % 2 == 1,
- lambda: jnp.matmul(mat, power, precision=precision),
- lambda: power)
+ power = jax.lax.cond(
+ i % 2 == 1,
+ lambda: jnp.matmul(mat, power, precision=precision),
+ lambda: power,
+ )
i //= 2
mat = jnp.matmul(mat, mat, precision=precision)
return i, power, mat
@@ -508,78 +547,81 @@ def _stable_subtract(b, a_minus_b):
return (b**exp) * jnp.expm1(exp * jnp.log1p(a_minus_b / b))
return jnp.where(
- # Choose the branch with the best log1p approximation.
- jnp.abs(a_minus_b / b) < jnp.abs(a_minus_b / a),
- -_stable_subtract(a, -a_minus_b),
- _stable_subtract(b, a_minus_b))
+ # Choose the branch with the best log1p approximation.
+ jnp.abs(a_minus_b / b) < jnp.abs(a_minus_b / a),
+ -_stable_subtract(a, -a_minus_b),
+ _stable_subtract(b, a_minus_b),
+ )
def matrix_inverse_pth_root(
- matrix,
- p,
- num_iters=100,
- ridge_epsilon=1e-6,
- error_tolerance=1e-6,
- precision=lax.Precision.HIGHEST,
- relative_matrix_epsilon=True,
- lobpcg_topk_precondition=0,
- lobpcg_max_iter=0,
- padding_start=None,
- prev=None,
- eigh=False,
+ matrix,
+ p,
+ num_iters=100,
+ ridge_epsilon=1e-6,
+ error_tolerance=1e-6,
+ precision=lax.Precision.HIGHEST,
+ relative_matrix_epsilon=True,
+ lobpcg_topk_precondition=0,
+ lobpcg_max_iter=0,
+ padding_start=None,
+ prev=None,
+ eigh=False,
):
"""Computes `matrix^(-1/p)`, where `p` is a positive integer.
- This function uses the Eigh or Coupled newton iterations algorithm for
- the computation of a matrix's inverse pth root.
-
-
- References:
- [Functions of Matrices, Theory and Computation,
- Nicholas J Higham, Pg 184, Eq 7.18](
- https://epubs.siam.org/doi/book/10.1137/1.9780898717778)
-
- Args:
- matrix: the symmetric PSD matrix whose power it to be computed
- p: exponent, for p a positive integer.
- num_iters: Maximum number of iterations.
- ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
- error_tolerance: Error indicator, useful for early termination.
- precision: precision XLA related flag, the available options are: a)
- lax.Precision.DEFAULT (better step time, but not precise) b)
- lax.Precision.HIGH (increased precision, slower) c)
- lax.Precision.HIGHEST (best possible precision, slowest)
- relative_matrix_epsilon: Whether to use relative epsilon to the max eigen
- value when computing inverse-pth root.
- lobpcg_topk_precondition: If nonzero, specifies the number of top
- eigenvectors to subtract out before performing LOBPCG. Note this makes
- relative_matrix_epsilon essentially free.
- lobpcg_max_iter: Maximum iteration count for LOBPCG, defaults to
- `lobpcg_topk_precondition`.
- padding_start: If the input matrix was padded, then zeros out columns and
- rows at the padding start.
- prev: previous iteration's solution, zero-padded (unused)
- eigh: If True, uses eigh for inverse-pth root computation.
-
- Returns:
- `(matrix + eps)^(-1/p)` and error metrics.
-
- Note `eps` is not added to zeroed out padding rows and
- columns. `eps` is just `ridge_epsilon` if
- `relative_matrix_epsilon` is set to `False`, otherwise, it is the
- ridge epsilon value scaled by the derived maximum eigenvalue of
- the input matrix.
- """
+ This function uses the eigh or coupled Newton iterations algorithm for
+ the computation of a matrix's inverse pth root.
+
+
+ References:
+ [Functions of Matrices, Theory and Computation,
+ Nicholas J Higham, Pg 184, Eq 7.18](
+ https://epubs.siam.org/doi/book/10.1137/1.9780898717778)
+
+ Args:
+ matrix: the symmetric PSD matrix whose power is to be computed
+ p: exponent, for p a positive integer.
+ num_iters: Maximum number of iterations.
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
+ error_tolerance: Error indicator, useful for early termination.
+ precision: precision XLA related flag, the available options are: a)
+ lax.Precision.DEFAULT (better step time, but not precise) b)
+ lax.Precision.HIGH (increased precision, slower) c)
+ lax.Precision.HIGHEST (best possible precision, slowest)
+ relative_matrix_epsilon: Whether to use relative epsilon to the max eigen
+ value when computing inverse-pth root.
+ lobpcg_topk_precondition: If nonzero, specifies the number of top
+ eigenvectors to subtract out before performing LOBPCG. Note this makes
+ relative_matrix_epsilon essentially free.
+ lobpcg_max_iter: Maximum iteration count for LOBPCG, defaults to
+ `lobpcg_topk_precondition`.
+ padding_start: If the input matrix was padded, then zeros out columns and
+ rows at the padding start.
+ prev: previous iteration's solution, zero-padded (unused)
+ eigh: If True, uses eigh for inverse-pth root computation.
+
+ Returns:
+ `(matrix + eps)^(-1/p)` and error metrics.
+
+ Note `eps` is not added to zeroed out padding rows and
+ columns. `eps` is just `ridge_epsilon` if
+ `relative_matrix_epsilon` is set to `False`, otherwise, it is the
+ ridge epsilon value scaled by the derived maximum eigenvalue of
+ the input matrix.
+ """
if eigh:
- return matrix_inverse_pth_root_eigh(matrix,
- p,
- ridge_epsilon,
- error_tolerance,
- precision,
- relative_matrix_epsilon,
- padding_start,
- prev)
+ return matrix_inverse_pth_root_eigh(
+ matrix,
+ p,
+ ridge_epsilon,
+ error_tolerance,
+ precision,
+ relative_matrix_epsilon,
+ padding_start,
+ prev,
+ )
del prev
assert matrix.shape[0] == matrix.shape[1]
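
A hedged usage sketch for this routine (assuming the module is importable from its path in this repository; the residual field name is taken from `TrainingMetrics` above):

```python
import jax.numpy as jnp
from reference_algorithms.paper_baselines.shampoo.jax.distributed_shampoo import (
  matrix_inverse_pth_root,
)

a = jnp.array([[4.0, 1.0], [1.0, 3.0]])  # small SPD test matrix
root, metrics = matrix_inverse_pth_root(a, p=2)
# `root` approximates (a + eps)^(-1/2); the max entrywise residual of
# root^p @ a vs. the identity is reported in metrics.inverse_pth_root_errors.
```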
@@ -596,7 +638,8 @@ def matrix_inverse_pth_root(
if padding_start is not None:
# Zero out padding in identity as well for convergence checks.
ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
- matrix.dtype)
+ matrix.dtype
+ )
matrix *= ix[jnp.newaxis, :]
matrix *= ix[:, jnp.newaxis]
identity *= ix
@@ -607,18 +650,23 @@ def matrix_inverse_pth_root(
eigvals, eigvecs, lobpcg_diagnostics = None, None, None
if lobpcg_topk_precondition > 0:
# TODO(vladf): reuse previous top-k as the initial search directions
- pad_shape = (matrix_size - lobpcg_topk_precondition,
- lobpcg_topk_precondition)
+ pad_shape = (
+ matrix_size - lobpcg_topk_precondition,
+ lobpcg_topk_precondition,
+ )
search_dirs = jnp.concatenate(
- (jnp.eye(lobpcg_topk_precondition), jnp.zeros(pad_shape)), axis=0)
+ (jnp.eye(lobpcg_topk_precondition), jnp.zeros(pad_shape)), axis=0
+ )
eigvals, eigvecs, lobpcg_iters = linalg.lobpcg_standard( # pylint: disable=unbalanced-tuple-unpacking
- matrix, search_dirs,
- lobpcg_topk_precondition if lobpcg_max_iter == 0 else lobpcg_max_iter)
+ matrix,
+ search_dirs,
+ lobpcg_topk_precondition if lobpcg_max_iter == 0 else lobpcg_max_iter,
+ )
lobpcg_diagnostics = LOBPCGDiagnostics.create(
- matrix,
- eigvals,
- eigvecs,
- lobpcg_iters,
+ matrix,
+ eigvals,
+ eigvecs,
+ lobpcg_iters,
)
# The minimal eigenvalue among top-k becomes the maximal one in the whole
@@ -628,7 +676,8 @@ def matrix_inverse_pth_root(
# Deflate out top eigenvectors to reduce matrix condition number.
matrix -= scaled_vecs.dot(
- scaled_vecs.T, precision=jax.lax.Precision.HIGHEST)
+ scaled_vecs.T, precision=jax.lax.Precision.HIGHEST
+ )
if relative_matrix_epsilon:
if eigvals is not None:
@@ -637,11 +686,12 @@ def matrix_inverse_pth_root(
# Only use power iteration if lobpcg wasn't already used to derive the
# top eigenvalue.
_, max_ev = power_iteration(
- matrix=matrix,
- num_iters=100,
- error_tolerance=1e-6,
- precision=precision,
- padding_start=padding_start)
+ matrix=matrix,
+ num_iters=100,
+ error_tolerance=1e-6,
+ precision=precision,
+ padding_start=padding_start,
+ )
else:
# Use absolute matrix epsilon scaling otherwise.
max_ev = 1.0
@@ -654,8 +704,9 @@ def matrix_inverse_pth_root(
def _iter_condition(state):
i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, error_ratio = state
- error_above_threshold = jnp.logical_and(error > error_tolerance,
- error_ratio < max_error_ratio)
+ error_above_threshold = jnp.logical_and(
+ error > error_tolerance, error_ratio < max_error_ratio
+ )
return jnp.logical_and(i < num_iters, error_above_threshold)
def _iter_body(state):
@@ -673,7 +724,6 @@ def _iter_body(state):
iters = 0
error_ratio = 0.0
else:
-
retry_loop_error_threshold = 0.05
num_tries = 6
init_outer_state = tuple([0, identity, 1000.0, 100, 1.0, True])
@@ -691,23 +741,26 @@ def _outer_body_fn(state):
new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
init_state = tuple(
- [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, 1.0])
+ [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, 1.0]
+ )
iters, mat_m, mat_h, old_mat_h, error, error_ratio = lax.while_loop(
- _iter_condition, _iter_body, init_state)
+ _iter_condition, _iter_body, init_state
+ )
error = jnp.max(jnp.abs(mat_m - identity)).astype(jnp.float32)
is_converged = jnp.asarray(error_ratio < max_error_ratio, old_mat_h.dtype)
- resultant_mat_h = is_converged * \
- mat_h + (1 - is_converged) * old_mat_h
- return (i + 1,
- resultant_mat_h,
- error,
- iters,
- error_ratio,
- error > retry_loop_error_threshold)
-
- loop_outputs = jax.lax.while_loop(_outer_iter_condition_fn,
- _outer_body_fn,
- init_outer_state)
+ resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
+ return (
+ i + 1,
+ resultant_mat_h,
+ error,
+ iters,
+ error_ratio,
+ error > retry_loop_error_threshold,
+ )
+
+ loop_outputs = jax.lax.while_loop(
+ _outer_iter_condition_fn, _outer_body_fn, init_outer_state
+ )
total_retries, resultant_mat_h, error, iters, error_ratio, _ = loop_outputs
conditioned_resultant_mat = resultant_mat_h
@@ -723,35 +776,39 @@ def _outer_body_fn(state):
pth_diff = _pth_root_difference(ridge_epsilon, jnp.min(eigvals), eigvals, p)
scaled_vecs = eigvecs * jnp.sqrt(pth_diff)
resultant_mat_h = conditioned_resultant_mat - scaled_vecs.dot(
- scaled_vecs.T, precision=jax.lax.Precision.HIGHEST)
+ scaled_vecs.T, precision=jax.lax.Precision.HIGHEST
+ )
error_metrics = TrainingMetrics(
- inverse_pth_root_errors=jnp.array(error, jnp.float32),
- inverse_pth_root_iters=jnp.array(iters, jnp.float32),
- final_error_ratio=jnp.array(error_ratio, jnp.float32),
- max_eigen_value=jnp.array(max_ev, jnp.float32),
- total_retries=jnp.array(total_retries, jnp.float32))
+ inverse_pth_root_errors=jnp.array(error, jnp.float32),
+ inverse_pth_root_iters=jnp.array(iters, jnp.float32),
+ final_error_ratio=jnp.array(error_ratio, jnp.float32),
+ max_eigen_value=jnp.array(max_ev, jnp.float32),
+ total_retries=jnp.array(total_retries, jnp.float32),
+ )
if lobpcg_topk_precondition > 0:
- damped_matrix = matrix + \
- (ridge_epsilon * (10**total_retries) * identity)
+ damped_matrix = matrix + (ridge_epsilon * (10**total_retries) * identity)
conditioned_diagnostics = InversePthRootDiagnostics.create(
- conditioned_resultant_mat, damped_matrix, p)
+ conditioned_resultant_mat, damped_matrix, p
+ )
unconditioned_damped_matrix = original_matrix + ridge_epsilon * identity
unconditioned_diagnostics = InversePthRootDiagnostics.create(
- resultant_mat_h, unconditioned_damped_matrix, p)
+ resultant_mat_h, unconditioned_damped_matrix, p
+ )
# The max entrywise error in error_metrics.inverse_pth_root_errors refers
# to what was derived from the inverse pth root iteration, which with
# LOBPCG refers to the conditioned problem. Make sure to use the error
# from the unconditioned problem.
unconditional_errors = jnp.maximum(
- unconditioned_diagnostics.max_diag_error,
- unconditioned_diagnostics.max_off_diag_error)
+ unconditioned_diagnostics.max_diag_error,
+ unconditioned_diagnostics.max_off_diag_error,
+ )
error_metrics = error_metrics.replace(
- inverse_pth_root_errors=unconditional_errors,
- lobpcg_diagnostics=lobpcg_diagnostics,
- conditioned_inverse_pth_root_diagnostics=conditioned_diagnostics,
- inverse_pth_root_diagnostics=unconditioned_diagnostics,
+ inverse_pth_root_errors=unconditional_errors,
+ lobpcg_diagnostics=lobpcg_diagnostics,
+ conditioned_inverse_pth_root_diagnostics=conditioned_diagnostics,
+ inverse_pth_root_diagnostics=unconditioned_diagnostics,
)
if padding_start is not None:
@@ -759,9 +816,9 @@ def _outer_body_fn(state):
# due to some TPU hosts not having the same number of preconditioning
# matrices.
resultant_mat_h = jnp.where(padding_start == 0, 0.0, resultant_mat_h)
- error = jnp.where(padding_start == 0,
- 0.0,
- error_metrics.inverse_pth_root_errors)
+ error = jnp.where(
+ padding_start == 0, 0.0, error_metrics.inverse_pth_root_errors
+ )
error_metrics = error_metrics.replace(inverse_pth_root_errors=error)
resultant_mat_h = jnp.asarray(resultant_mat_h, orig_dtype)
@@ -769,44 +826,44 @@ def _outer_body_fn(state):
def matrix_inverse_pth_root_eigh(
- matrix,
- p,
- ridge_epsilon=1e-6,
- error_tolerance=1e-6,
- precision=lax.Precision.HIGHEST,
- relative_matrix_epsilon=True,
- padding_start=None,
- prev=None,
+ matrix,
+ p,
+ ridge_epsilon=1e-6,
+ error_tolerance=1e-6,
+ precision=lax.Precision.HIGHEST,
+ relative_matrix_epsilon=True,
+ padding_start=None,
+ prev=None,
):
"""Computes `matrix^(-1/p)`, where `p` is a positive integer.
- This function uses eigh for the computation of a matrix's inverse pth
- root.
-
- Args:
- matrix: the symmetric PSD matrix whose power it to be computed
- p: exponent, for p a positive integer.
- ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
- error_tolerance: Error indicator, useful for early termination.
- precision: precision XLA related flag, the available options are: a)
- lax.Precision.DEFAULT (better step time, but not precise) b)
- lax.Precision.HIGH (increased precision, slower) c)
- lax.Precision.HIGHEST (best possible precision, slowest)
- relative_matrix_epsilon: Whether to use relative epsilon to the max eigen
- value when computing inverse-pth root.
- padding_start: If the input matrix was padded, then zeros out columns and
- rows at the padding start.
- prev: previous iteration's solution, zero-padded (unused)
-
- Returns:
- `(matrix + eps)^(-1/p)` and error metrics.
-
- Note `eps` is not added to zeroed out padding rows and
- columns. `eps` is just `ridge_epsilon` if
- `relative_matrix_epsilon` is set to `False`, otherwise, it is the
- ridge epsilon value scaled by the derived maximum eigenvalue of
- the input matrix.
- """
+ This function uses eigh for the computation of a matrix's inverse pth
+ root.
+
+ Args:
+ matrix: the symmetric PSD matrix whose power is to be computed
+ p: exponent, for p a positive integer.
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
+ error_tolerance: Error indicator, useful for early termination.
+ precision: precision XLA related flag, the available options are: a)
+ lax.Precision.DEFAULT (better step time, but not precise) b)
+ lax.Precision.HIGH (increased precision, slower) c)
+ lax.Precision.HIGHEST (best possible precision, slowest)
+ relative_matrix_epsilon: Whether to use relative epsilon to the max eigen
+ value when computing inverse-pth root.
+ padding_start: If the input matrix was padded, then zeros out columns and
+ rows at the padding start.
+ prev: previous iteration's solution, zero-padded (unused)
+
+ Returns:
+ `(matrix + eps)^(-1/p)` and error metrics.
+
+ Note `eps` is not added to zeroed out padding rows and
+ columns. `eps` is just `ridge_epsilon` if
+ `relative_matrix_epsilon` is set to `False`, otherwise, it is the
+ ridge epsilon value scaled by the derived maximum eigenvalue of
+ the input matrix.
+ """
del prev
assert matrix.shape[0] == matrix.shape[1]
matrix_size = matrix.shape[0]
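
Mathematically, the eigh path is just a spectral function; a minimal NumPy sketch of the identity the docstring describes (omitting the relative-epsilon scaling and padding handling above):

```python
import numpy as np

def inverse_pth_root_eigh_np(matrix, p, ridge_epsilon=1e-6):
  """matrix^(-1/p) via eigendecomposition of a symmetric PSD matrix."""
  e, u = np.linalg.eigh(matrix)
  inv_e = np.power(np.maximum(e, ridge_epsilon), -1.0 / p)
  return (u * inv_e) @ u.T  # U diag(e^(-1/p)) U^T

a = np.array([[4.0, 1.0], [1.0, 3.0]])
b = inverse_pth_root_eigh_np(a, p=2)
assert np.allclose(np.linalg.matrix_power(b, 2) @ a, np.eye(2), atol=1e-5)
```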
@@ -816,17 +873,19 @@ def matrix_inverse_pth_root_eigh(
identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
if padding_start is not None:
ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
- matrix.dtype)
+ matrix.dtype
+ )
matrix *= ix[jnp.newaxis, :]
matrix *= ix[:, jnp.newaxis]
identity *= ix
if relative_matrix_epsilon:
_, max_ev = power_iteration(
- matrix=matrix,
- num_iters=100,
- error_tolerance=error_tolerance,
- precision=precision,
- padding_start=padding_start)
+ matrix=matrix,
+ num_iters=100,
+ error_tolerance=error_tolerance,
+ precision=precision,
+ padding_start=padding_start,
+ )
else:
# Use absolute matrix epsilon scaling otherwise.
max_ev = 1.0
@@ -837,9 +896,9 @@ def matrix_inverse_pth_root_eigh(
if padding_start is not None:
e *= jnp.flip(ix)
mm = functools.partial(jnp.matmul, precision=precision)
- inv_e = jnp.where(e == 0.0,
- 0.0,
- jnp.power(jnp.maximum(e, ridge_epsilon), alpha))
+ inv_e = jnp.where(
+ e == 0.0, 0.0, jnp.power(jnp.maximum(e, ridge_epsilon), alpha)
+ )
val = mm(mm(u, jnp.diag(inv_e)), u.T)
root = u * jnp.sqrt(inv_e)
val = mm(root, root.T)
@@ -849,12 +908,13 @@ def matrix_inverse_pth_root_eigh(
eig_error *= jnp.flip(ix)
error = jnp.max(jnp.abs(eig_error))
error_metrics = TrainingMetrics(
- inverse_pth_root_errors=jnp.array(error, jnp.float32))
+ inverse_pth_root_errors=jnp.array(error, jnp.float32)
+ )
if padding_start is not None:
val = jnp.where(padding_start == 0, 0.0, val)
- error = jnp.where(padding_start == 0,
- 0.0,
- error_metrics.inverse_pth_root_errors)
+ error = jnp.where(
+ padding_start == 0, 0.0, error_metrics.inverse_pth_root_errors
+ )
error_metrics = error_metrics.replace(inverse_pth_root_errors=error)
val = jnp.asarray(val, orig_dtype)
return val, error_metrics
@@ -863,17 +923,17 @@ def matrix_inverse_pth_root_eigh(
def merge_small_dims(shape_to_merge, max_dim):
"""Merge small dimensions.
- If there are some small dimensions, we collapse them:
- e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
- [1, 2, 768, 1, 2048] --> [2, 768, 2048]
+ If there are some small dimensions, we collapse them:
+ e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
+ [1, 2, 768, 1, 2048] --> [2, 768, 2048]
- Args:
- shape_to_merge: Shape to merge small dimensions.
- max_dim: Maximal dimension of output shape used in merging.
+ Args:
+ shape_to_merge: Shape to merge small dimensions.
+ max_dim: Maximal dimension of output shape used in merging.
- Returns:
- Merged shape.
- """
+ Returns:
+ Merged shape.
+ """
if shape_to_merge and np.all(np.array(shape_to_merge) == 1):
return [1]
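
The docstring's examples, written out as a doctest-style check (assuming the module is importable from its path in this repository):

```python
from reference_algorithms.paper_baselines.shampoo.jax.distributed_shampoo import (
  merge_small_dims,
)

assert merge_small_dims([1, 2, 512, 1, 2048, 1, 3, 4], 1024) == [1024, 2048, 12]
assert merge_small_dims([1, 2, 768, 1, 2048], 1024) == [2, 768, 2048]
assert merge_small_dims([1, 1, 1], 1024) == [1]  # all-ones shapes collapse to [1]
```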
@@ -894,20 +954,23 @@ def merge_small_dims(shape_to_merge, max_dim):
def pad_square_matrix(mat, max_size):
"""Pad a square matrix up to max_size.
- Args:
- mat: a matrix to pad.
- max_size: matrix size requested.
+ Args:
+ mat: a matrix to pad.
+ max_size: matrix size requested.
- Returns:
- Given M returns [[M, 0], [0, I]]
- """
+ Returns:
+ Given M returns [[M, 0], [0, I]]
+ """
rows, cols = mat.shape
if rows != cols:
- raise ValueError("Must have rows == cols, instead got "
- f"rows={rows}, cols={cols}")
+ raise ValueError(
+ f'Must have rows == cols, instead got rows={rows}, cols={cols}'
+ )
if cols > max_size:
- raise ValueError("Must have cols <= max_size. Instead got "
- f"cols={cols}, max_size={max_size}.")
+ raise ValueError(
+ 'Must have cols <= max_size. Instead got '
+ f'cols={cols}, max_size={max_size}.'
+ )
if rows == max_size:
return mat
pad_size = max_size - rows
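
For instance, padding a 2x2 matrix to size 4 produces the `[[M, 0], [0, I]]` block layout the docstring promises (a hedged check, assuming the same import path as above):

```python
import numpy as np
from reference_algorithms.paper_baselines.shampoo.jax.distributed_shampoo import (
  pad_square_matrix,
)

m = np.array([[1.0, 2.0], [3.0, 4.0]])
padded = pad_square_matrix(m, max_size=4)
expected = np.block([
  [m, np.zeros((2, 2))],
  [np.zeros((2, 2)), np.eye(2)],
])
assert np.allclose(padded, expected)
```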
@@ -923,13 +986,13 @@ def pad_square_matrix(mat, max_size):
def pad_vector(vec, max_size):
"""Pad a vector to a max_size.
- Args:
- vec: a vector to pad.
- max_size: matrix size requested.
+ Args:
+ vec: a vector to pad.
+ max_size: matrix size requested.
- Returns:
- Given V returns [V, 0]
- """
+ Returns:
+ Given V returns [V, 0]
+ """
size = vec.shape[0]
assert size <= max_size
if size == max_size:
@@ -949,9 +1012,9 @@ def _iter_body(unused_state):
def _iter_condition(state):
return state[0]
- results = jax.lax.while_loop(_iter_condition,
- _iter_body,
- tuple([predicate] + init_state))
+ results = jax.lax.while_loop(
+ _iter_condition, _iter_body, tuple([predicate] + init_state)
+ )
return tuple(results[1:])
@@ -985,7 +1048,7 @@ def partition(self, tensor):
assert tensor.shape == self._shape
tensors = [tensor]
- for (i, indices) in self._splits:
+ for i, indices in self._splits:
tensors_local = []
for t in tensors:
tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i))
@@ -995,13 +1058,14 @@ def partition(self, tensor):
def merge_partitions(self, partitions):
"""Merge partitions back to original shape."""
- for (i, indices) in reversed(self._splits):
+ for i, indices in reversed(self._splits):
n = len(indices) + 1
partial_merged_tensors = []
ind = 0
while ind < len(partitions):
partial_merged_tensors.append(
- jnp.concatenate(partitions[ind:ind + n], axis=i))
+ jnp.concatenate(partitions[ind : ind + n], axis=i)
+ )
ind += n
partitions = partial_merged_tensors
assert len(partitions) == 1
@@ -1011,25 +1075,25 @@ def merge_partitions(self, partitions):
def gram_weighted_update(old_stats, g, axis, w1, w2, precision=None):
"""Updated statistics via weighted average with new Gram matrix.
- Returns w₁ R + w₂ Gᵀ G where R is `old_stats` and G is the matrix whose
- columns are the flattened slices of the tensor `g` along the given `axis`.
- (So, `old_stats` and the returned matrix have dimensions n x n where
- n = `g.shape[axis]`).
-
- Args:
- old_stats: Old statistics.
- g: Gradient tensor.
- axis: Axis along which to slice `g`.
- w1: Scalar weight for old statistics.
- w2: Scalar weight for new Gram matrix.
- precision: Optional precision XLA related flag, the available options are:
- a) lax.Precision.DEFAULT (better step time, but not precise) b)
- lax.Precision.HIGH (increased precision, slower) c)
- lax.Precision.HIGHEST (best possible precision, slowest)
-
- Returns:
- Weighted average of old and new statistics.
- """
+ Returns w₁ R + w₂ Gᵀ G where R is `old_stats` and G is the matrix whose
+ columns are the flattened slices of the tensor `g` along the given `axis`.
+ (So, `old_stats` and the returned matrix have dimensions n x n where
+ n = `g.shape[axis]`).
+
+ Args:
+ old_stats: Old statistics.
+ g: Gradient tensor.
+ axis: Axis along which to slice `g`.
+ w1: Scalar weight for old statistics.
+ w2: Scalar weight for new Gram matrix.
+ precision: Optional precision XLA related flag, the available options are:
+ a) lax.Precision.DEFAULT (better step time, but not precise) b)
+ lax.Precision.HIGH (increased precision, slower) c)
+ lax.Precision.HIGHEST (best possible precision, slowest)
+
+ Returns:
+ Weighted average of old and new statistics.
+ """
axes = [i for i in range(g.ndim) if i != axis]
gram_matrix = jnp.tensordot(g, g, axes=(axes, axes), precision=precision)
return w1 * old_stats + w2 * gram_matrix
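
A small NumPy check of the `w₁ R + w₂ Gᵀ G` update the docstring describes, using a random 3-D tensor sliced along axis 1 (illustrative only):

```python
import numpy as np

rng = np.random.default_rng(0)
g = rng.standard_normal((2, 5, 3))
axis, w1, w2 = 1, 0.9, 0.1
old_stats = np.eye(g.shape[axis])

# G: flatten every axis except `axis` into rows, giving shape (6, 5) here.
g_mat = np.moveaxis(g, axis, -1).reshape(-1, g.shape[axis])
expected = w1 * old_stats + w2 * g_mat.T @ g_mat

# Same contraction the function performs with tensordot.
axes = [i for i in range(g.ndim) if i != axis]
gram = np.tensordot(g, g, axes=(axes, axes))
assert np.allclose(w1 * old_stats + w2 * gram, expected)
```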
@@ -1039,67 +1103,68 @@ class Preconditioner:
"""Compute statistics/shape from gradients for preconditioning."""
def __init__(
- self,
- param,
- block_size,
- merge_small_dims_block_size,
- best_effort_shape_interpretation,
- preconditioner_type=PreconditionerType.ALL,
+ self,
+ param,
+ block_size,
+ merge_small_dims_block_size,
+ best_effort_shape_interpretation,
+ preconditioner_type=PreconditionerType.ALL,
):
"""Initializes the preconditioner.
- Args:
- param: parameter to precondition.
- block_size: Block size used to split param.
- merge_small_dims_block_size: Block size for merging dims.
- best_effort_shape_interpretation: Whether to
- collapse/merge dims together.
- preconditioner_type: Type of preconditioner to use.
- """
+ Args:
+ param: parameter to precondition.
+ block_size: Block size used to split param.
+ merge_small_dims_block_size: Block size for merging dims.
+ best_effort_shape_interpretation: Whether to
+ collapse/merge dims together.
+ preconditioner_type: Type of preconditioner to use.
+ """
self._original_shape = param.shape
self._transformed_shape = param.shape
if best_effort_shape_interpretation:
- self._transformed_shape = merge_small_dims(self._original_shape,
- merge_small_dims_block_size)
+ self._transformed_shape = merge_small_dims(
+ self._original_shape, merge_small_dims_block_size
+ )
reshaped_param = jnp.reshape(param, self._transformed_shape)
self._partitioner = BlockPartitioner(reshaped_param, block_size)
self._preconditioner_type = preconditioner_type
def updated_statistics_from_grad(
- self,
- stats,
- grad,
- w1,
- w2,
- to_float=None,
- from_float=None,
- precision=None,
+ self,
+ stats,
+ grad,
+ w1,
+ w2,
+ to_float=None,
+ from_float=None,
+ precision=None,
):
"""Update statistics from gradients.
- Args:
- stats: Old statistics or its Cholesky factor if `cholesky` is True.
- grad: Gradient to compute statistics from.
- w1: Weight for old statistics.
- w2: Weight for new statistics.
- to_float: Optional function for converting stats to floating point.
- from_float: Optional function for converting from floating point.
- precision: Optional precision XLA related flag, the available options
- are:
- a) lax.Precision.DEFAULT (better step time, but not precise) b)
- lax.Precision.HIGH (increased precision, slower) c)
- lax.Precision.HIGHEST (best possible precision, slowest)
-
- Returns:
- A list of updated gradient statistics for each partition.
- """
+ Args:
+ stats: Old statistics or its Cholesky factor if `cholesky` is True.
+ grad: Gradient to compute statistics from.
+ w1: Weight for old statistics.
+ w2: Weight for new statistics.
+ to_float: Optional function for converting stats to floating point.
+ from_float: Optional function for converting from floating point.
+ precision: Optional precision XLA related flag, the available options
+ are:
+ a) lax.Precision.DEFAULT (better step time, but not precise) b)
+ lax.Precision.HIGH (increased precision, slower) c)
+ lax.Precision.HIGHEST (best possible precision, slowest)
+
+ Returns:
+ A list of updated gradient statistics for each partition.
+ """
to_float = to_float if to_float is not None else (lambda x: x)
from_float = from_float if from_float is not None else (lambda x: x)
reshaped_grad = jnp.reshape(grad, self._transformed_shape)
partitioned_grads = self._partitioner.partition(reshaped_grad)
should_preconditioned_dims = self.should_precondition_dims()
preconditioned_dims = [
- i for i, p in enumerate(should_preconditioned_dims) if p
+ i for i, p in enumerate(should_preconditioned_dims) if p
]
new_stats = []
index = 0
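
A hedged sketch of how a `Preconditioner` is typically driven; the hyperparameter values are illustrative and the zero-initialized statistics are only a placeholder (the optimizer itself seeds them with `matrix_epsilon * I`):

```python
import jax.numpy as jnp
from reference_algorithms.paper_baselines.shampoo.jax.distributed_shampoo import (
  Preconditioner,
  PreconditionerType,
)

param = jnp.zeros((512, 256))
prec = Preconditioner(
  param,
  block_size=128,
  merge_small_dims_block_size=4096,
  best_effort_shape_interpretation=True,
  preconditioner_type=PreconditionerType.ALL,
)

# One square statistic per preconditioned dimension of each block.
stats = [jnp.zeros(tuple(s)) for s in prec.shapes_for_preconditioners()]
grad = jnp.ones_like(param)
new_stats = prec.updated_statistics_from_grad(stats, grad, w1=0.999, w2=0.001)
```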
@@ -1136,8 +1201,7 @@ def _preconds_for_grad(self, preconditioners, rank, start, end):
elif self._preconditioner_type == PreconditionerType.OUTPUT:
# When _preconditioner_type is OUTPUT, we append (rank - 1) many None
# values to the beginning of the list to handle the False indices.
- preconditioners_for_grad = [None] * \
- (rank - 1) + preconditioners_for_grad
+ preconditioners_for_grad = [None] * (rank - 1) + preconditioners_for_grad
assert len(preconditioners_for_grad) == rank
return preconditioners_for_grad
@@ -1165,13 +1229,13 @@ def exponent_for_preconditioner(self):
def preconditioned_grad(self, grad, preconditioners):
"""Precondition the gradient.
- Args:
- grad: A gradient tensor to precondition.
- preconditioners: A list of preconditioners to apply.
+ Args:
+ grad: A gradient tensor to precondition.
+ preconditioners: A list of preconditioners to apply.
- Returns:
- A preconditioned gradient.
- """
+ Returns:
+ A preconditioned gradient.
+ """
reshaped_grad = jnp.reshape(grad, self._transformed_shape)
partitioned_grads = self._partitioner.partition(reshaped_grad)
should_preconditioned_dims = self.should_precondition_dims()
@@ -1179,17 +1243,18 @@ def preconditioned_grad(self, grad, preconditioners):
preconditioned_partitioned_grads = []
for i, g in enumerate(partitioned_grads):
preconditioners_for_grad = self._preconds_for_grad(
- preconditioners,
- rank=len(should_preconditioned_dims),
- start=i * num_preconditioners,
- end=(i + 1) * num_preconditioners,
+ preconditioners,
+ rank=len(should_preconditioned_dims),
+ start=i * num_preconditioners,
+ end=(i + 1) * num_preconditioners,
+ )
+ precond_g = self._precondition_block(
+ g, should_preconditioned_dims, preconditioners_for_grad
)
- precond_g = self._precondition_block(g,
- should_preconditioned_dims,
- preconditioners_for_grad)
preconditioned_partitioned_grads.append(precond_g)
merged_grad = self._partitioner.merge_partitions(
- preconditioned_partitioned_grads)
+ preconditioned_partitioned_grads
+ )
return jnp.reshape(merged_grad, self._original_shape)
def _precondition_block(self, g, should_precondition_dim, preconditioners):
@@ -1208,9 +1273,9 @@ def _precondition_block(self, g, should_precondition_dim, preconditioners):
return g
-def _convert_to_parameter_stats(global_stats,
- local_stat,
- convert_statistics=True):
+def _convert_to_parameter_stats(
+ global_stats, local_stat, convert_statistics=True
+):
"""Creates parameter stats from sharded stats."""
index_start = int(local_stat.index_start)
index_end = int(len(local_stat.sizes)) + index_start
@@ -1225,24 +1290,24 @@ def _convert_to_parameter_stats(global_stats,
if not convert_statistics:
new_statistics = None
return ParameterStats(
- local_stat.diagonal_statistics,
- new_statistics,
- new_preconditioners,
- local_stat.diagonal_momentum,
- local_stat.momentum,
- local_stat.training_metrics,
+ local_stat.diagonal_statistics,
+ new_statistics,
+ new_preconditioners,
+ local_stat.diagonal_momentum,
+ local_stat.momentum,
+ local_stat.training_metrics,
)
def _convert_from_parameter_stats(parameter_stats, local_stats):
"""Creates sharded stats from paramter stats."""
return LocalShardedParameterStats(
- parameter_stats.diagonal_statistics,
- parameter_stats.diagonal_momentum,
- parameter_stats.momentum,
- parameter_stats.training_metrics,
- local_stats.index_start,
- local_stats.sizes,
+ parameter_stats.diagonal_statistics,
+ parameter_stats.diagonal_momentum,
+ parameter_stats.momentum,
+ parameter_stats.training_metrics,
+ local_stats.index_start,
+ local_stats.sizes,
)
@@ -1258,12 +1323,13 @@ def _add_metrics_into_local_stats(local_stats, metrics, keep_old):
# root calculation to find a new preconditioner, so that TensorBoard curves
# look consistent (otherwise they'd oscillate between NaN and measured
# values).
- per_stat_metrics = efficient_cond(keep_old,
- lambda: [local_stat.training_metrics],
- [per_stat_metrics])[0]
+ per_stat_metrics = efficient_cond(
+ keep_old, lambda: [local_stat.training_metrics], [per_stat_metrics]
+ )[0]
# pylint:enable=cell-var-from-loop
new_local_stats.append(
- local_stat.replace(training_metrics=per_stat_metrics))
+ local_stat.replace(training_metrics=per_stat_metrics)
+ )
return new_local_stats
@@ -1271,7 +1337,7 @@ def batch(x, num_devices):
"""Batch `x` so that so that leading axis is num_devices."""
n = len(x)
b = int(n / num_devices)
- return jnp.stack([jnp.stack(x[idx:idx + b]) for idx in range(0, n, b)])
+ return jnp.stack([jnp.stack(x[idx : idx + b]) for idx in range(0, n, b)])
def unbatch(batched_values):
@@ -1290,162 +1356,168 @@ def unbatch(batched_values):
def distributed_shampoo(
- learning_rate,
- block_size=1024,
- beta1=0.9,
- beta2=0.999,
- diagonal_epsilon=1e-8,
- matrix_epsilon=1e-6,
- weight_decay=0.0,
- start_preconditioning_step=101,
- preconditioning_compute_steps=20,
- statistics_compute_steps=1,
- best_effort_shape_interpretation=True,
- graft_type=GraftingType.RMSPROP_NORMALIZED,
- nesterov=True,
- exponent_override=0,
- # Pass pmap 'batch axis name' in pmap mode.
- batch_axis_name=None,
- # Only set following 3 params in pjit/spmd mode.
- # WARNING: Experimental
- statistics_partition_spec=None,
- preconditioner_partition_spec=None,
- num_devices_for_pjit=None,
- shard_optimizer_states=False,
- ###
- # Experimental memory reduction mode
- best_effort_memory_usage_reduction=True,
- ###
- inverse_failure_threshold=0.1,
- moving_average_for_momentum=True,
- skip_preconditioning_dim_size_gt=0,
- clip_by_scaled_gradient_norm=None,
- precision=lax.Precision.HIGHEST,
- tensordot_precision=None,
- relative_matrix_epsilon=True,
- merge_small_dims_block_size=4096,
- lobpcg_topk_precondition=0,
- lobpcg_max_iter=0,
- precondtioner_type=PreconditionerType.ALL,
- custom_preconditioner=False,
- skip_preconditioning_rank_lt=1,
- decoupled_learning_rate=True,
- decoupled_weight_decay=False,
- generate_training_metrics=True,
- reuse_preconditioner=False,
- eigh=True,
+ learning_rate,
+ block_size=1024,
+ beta1=0.9,
+ beta2=0.999,
+ diagonal_epsilon=1e-8,
+ matrix_epsilon=1e-6,
+ weight_decay=0.0,
+ start_preconditioning_step=101,
+ preconditioning_compute_steps=20,
+ statistics_compute_steps=1,
+ best_effort_shape_interpretation=True,
+ graft_type=GraftingType.RMSPROP_NORMALIZED,
+ nesterov=True,
+ exponent_override=0,
+ # Pass pmap 'batch axis name' in pmap mode.
+ batch_axis_name=None,
+ # Only set following 3 params in pjit/spmd mode.
+ # WARNING: Experimental
+ statistics_partition_spec=None,
+ preconditioner_partition_spec=None,
+ num_devices_for_pjit=None,
+ shard_optimizer_states=False,
+ ###
+ # Experimental memory reduction mode
+ best_effort_memory_usage_reduction=True,
+ ###
+ inverse_failure_threshold=0.1,
+ moving_average_for_momentum=True,
+ skip_preconditioning_dim_size_gt=0,
+ clip_by_scaled_gradient_norm=None,
+ precision=lax.Precision.HIGHEST,
+ tensordot_precision=None,
+ relative_matrix_epsilon=True,
+ merge_small_dims_block_size=4096,
+ lobpcg_topk_precondition=0,
+ lobpcg_max_iter=0,
+ precondtioner_type=PreconditionerType.ALL,
+ custom_preconditioner=False,
+ skip_preconditioning_rank_lt=1,
+ decoupled_learning_rate=True,
+ decoupled_weight_decay=False,
+ generate_training_metrics=True,
+ reuse_preconditioner=False,
+ eigh=True,
):
"""Distributed Shampoo optimizer.
- Distributed Shampoo is a second-order preconditioned method (concretely, a
- variant of full-matrix Adagrad), that provides significant convergence and
- wall-clock time improvements compared to conventional first-order methods,
- and that has been shown to scale to large state-of-the-art deep learning
- models.
-
- References:
- Scalable Second Order Optimization for Deep Learning,
- Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
-
- Preprint: https://arxiv.org/abs/2002.09018
-
- Args:
- learning_rate: the step size used to update the parameters.
- block_size: Block size for large layers (if > 0). Preconditioning compute
- operation is cubic in the dimension of the tensor. Block size allows us
- to chunk the layers into sub-layers of maximal dimension dictated by
- this value. Use 128 as default (increase if you have compute budget).
- beta1: momentum parameter.
- beta2: second moment averaging parameter.
- diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
- to AdaGrad is enabled).
- matrix_epsilon: epsilon to add to statistics before computing inverse pth
- root. If you are running in f32 precision for inverse pth root
- (recommended today) this can go upto 1e-6. If you have latest hardware
- with native f64 precision, set this upto 1e-12.
- weight_decay: Weight decay for regularization.
- start_preconditioning_step: When to start Shampoo update before which
- diagonal update is used. This is because we dont have enough information
- to do stable inverse.
- preconditioning_compute_steps: How often to compute preconditioner.
- Performance tuning params for controlling memory and compute
- requirements.
- Ideally set this and statistics_compute_steps params to 1.
- statistics_compute_steps: How often to compute statistics.
- best_effort_shape_interpretation: If there are some small dimensions,
- collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if
- block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
- graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
- optimizer. This allows us to plugin the Shampoo optimizer into settings
- where SGD/AdaGrad is already well tuned.
- nesterov: Nesterov momentum.
- exponent_override: Override the exponent used in matrix inverse.
- batch_axis_name: labeled axis over pmap for data-parallel training the
- optimizer used for.
- statistics_partition_spec: PartitionSpec to be used in sharded mode.
- preconditioner_partition_spec: PartitionSpec to be used in sharded mode.
- num_devices_for_pjit: Number of devices to parallelize over when using
- pjit.
- shard_optimizer_states: Shard optimizer states to save memory in model
- parallel training.
- best_effort_memory_usage_reduction: Best effort memory usage reduction. -
- diagonal_statistics -> jnp.bfloat16 - momentum buffers (2x)
- -> jnp.int8 - statistics, preconditioners -> jnp.int16 + diagonals
- inverse_failure_threshold: numerics are hard and inverses fail sometimes;
- we determine that using this threshold.
- moving_average_for_momentum: Whether to use moving average for momentum
- instead of exponential moving average.
- skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
- greater than this value.
- clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful
- when using RMSProp Grafting).
- precision: precision XLA related flag, the available options are: a)
- lax.Precision.DEFAULT (better step time, but not precise) b)
- lax.Precision.HIGH (increased precision, slower) c)
- lax.Precision.HIGHEST (best possible precision, slowest)
- tensordot_precision: Optional precision to use for the tensordot operation
- when computing statistics (e.g., G Gᵀ). Same options as `precision`
- above.
- relative_matrix_epsilon: Whether to use relative epsilon to the max eigen
- value when computing inverse-pth root.
- merge_small_dims_block_size: Used as the maximum block size to merge the
- shapes.
- lobpcg_topk_precondition: If nonzero, specifies the number of top
- eigenvectors to subtract out before performing LOBPCG. Note this makes
- relative_matrix_epsilon essentially free.
- lobpcg_max_iter: Number of LOBPCG iterations, if zero defaults to
- `lobpcg_topk_precondition`.
- precondtioner_type: Preconditioner type to select all, left only or right
- only preconditioners.
- skip_preconditioning_rank_lt: Skips preconditioning for parameters with
- rank less than this value.
- decoupled_learning_rate: If True, use decoupled learning rate, otherwise
- couple it with preconditioned gradient computation. (Default True)
- decoupled_weight_decay: If True, use decoupled weight decay, otherwise
- couple with weight decay. (Default False)
- generate_training_metrics: If True, gather training metrics, otherwise
- avoid generating them (to reduce memory usage).
- reuse_preconditioner: If True, pass the previous derived preconditioner
- as a warm start to the next iteratin's inverse pth root computation.
- eigh: If True, and uses eigen decomposition for inverse-pth root.
-
- Returns:
- a GradientTransformation.
- """
+ Distributed Shampoo is a second-order preconditioned method (concretely, a
+ variant of full-matrix Adagrad), that provides significant convergence and
+ wall-clock time improvements compared to conventional first-order methods,
+ and that has been shown to scale to large state-of-the-art deep learning
+ models.
+
+ References:
+ Scalable Second Order Optimization for Deep Learning,
+ Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
+
+ Preprint: https://arxiv.org/abs/2002.09018
+
+ Args:
+ learning_rate: the step size used to update the parameters.
+ block_size: Block size for large layers (if > 0). Preconditioning compute
+ operation is cubic in the dimension of the tensor. Block size allows us
+ to chunk the layers into sub-layers of maximal dimension dictated by
+ this value. Use 128 as default (increase if you have compute budget).
+ beta1: momentum parameter.
+ beta2: second moment averaging parameter.
+ diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
+ to AdaGrad is enabled).
+ matrix_epsilon: epsilon to add to statistics before computing inverse pth
+ root. If you are running in f32 precision for inverse pth root
+ (recommended today) this can go up to 1e-6. If you have the latest hardware
+ with native f64 precision, set this up to 1e-12.
+ weight_decay: Weight decay for regularization.
+ start_preconditioning_step: When to start Shampoo update before which
+ diagonal update is used. This is because we don't have enough information
+ to compute a stable inverse.
+ preconditioning_compute_steps: How often to compute preconditioner.
+ Performance tuning params for controlling memory and compute
+ requirements.
+ Ideally set this and statistics_compute_steps params to 1.
+ statistics_compute_steps: How often to compute statistics.
+ best_effort_shape_interpretation: If there are some small dimensions,
+ collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if
+ block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
+ graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
+ optimizer. This allows us to plug the Shampoo optimizer into settings
+ where SGD/AdaGrad is already well tuned.
+ nesterov: Nesterov momentum.
+ exponent_override: Override the exponent used in matrix inverse.
+ batch_axis_name: labeled axis over pmap for the data-parallel training the
+ optimizer is used for.
+ statistics_partition_spec: PartitionSpec to be used in sharded mode.
+ preconditioner_partition_spec: PartitionSpec to be used in sharded mode.
+ num_devices_for_pjit: Number of devices to parallelize over when using
+ pjit.
+ shard_optimizer_states: Shard optimizer states to save memory in model
+ parallel training.
+ best_effort_memory_usage_reduction: Best effort memory usage reduction:
+ diagonal_statistics -> jnp.bfloat16; momentum buffers (2x) -> jnp.int8;
+ statistics, preconditioners -> jnp.int16 + diagonals.
+ inverse_failure_threshold: numerics are hard and inverses fail sometimes;
+ we determine that using this threshold.
+ moving_average_for_momentum: Whether to use moving average for momentum
+ instead of exponential moving average.
+ skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
+ greater than this value.
+ clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful
+ when using RMSProp Grafting).
+ precision: precision XLA related flag, the available options are: a)
+ lax.Precision.DEFAULT (better step time, but not precise) b)
+ lax.Precision.HIGH (increased precision, slower) c)
+ lax.Precision.HIGHEST (best possible precision, slowest)
+ tensordot_precision: Optional precision to use for the tensordot operation
+ when computing statistics (e.g., G Gᵀ). Same options as `precision`
+ above.
+ relative_matrix_epsilon: Whether to use relative epsilon to the max eigen
+ value when computing inverse-pth root.
+ merge_small_dims_block_size: Used as the maximum block size to merge the
+ shapes.
+ lobpcg_topk_precondition: If nonzero, specifies the number of top
+ eigenvectors to subtract out before performing LOBPCG. Note this makes
+ relative_matrix_epsilon essentially free.
+ lobpcg_max_iter: Number of LOBPCG iterations, if zero defaults to
+ `lobpcg_topk_precondition`.
+ precondtioner_type: Preconditioner type to select all, left only or right
+ only preconditioners.
+ skip_preconditioning_rank_lt: Skips preconditioning for parameters with
+ rank less than this value.
+ decoupled_learning_rate: If True, use decoupled learning rate, otherwise
+ couple it with preconditioned gradient computation. (Default True)
+ decoupled_weight_decay: If True, use decoupled weight decay, otherwise
+ couple with weight decay. (Default False)
+ generate_training_metrics: If True, gather training metrics, otherwise
+ avoid generating them (to reduce memory usage).
+ reuse_preconditioner: If True, pass the previous derived preconditioner
+ as a warm start to the next iteration's inverse pth root computation.
+ eigh: If True, uses eigen decomposition for inverse-pth root.
+
+ Returns:
+ a GradientTransformation.
+ """
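For orientation, here is a minimal sketch of how this transformation is constructed and used, mirroring the call in reference_algorithms/paper_baselines/shampoo/jax/submission.py further down in this patch. The numeric hyperparameter values are illustrative placeholders, and all arguments not shown fall back to the defaults documented above:

# Sketch only: the values below are placeholders, not tuned settings.
opt_init_fn, opt_update_fn = distributed_shampoo(
  learning_rate=1e-3,  # may also be an optax schedule, as in submission.py
  beta1=0.9,
  beta2=0.999,
  weight_decay=1e-4,
  batch_axis_name='batch',  # pmap axis name for data-parallel training
  eigh=False,
)
# opt_init_fn(params) builds the initial ShampooState;
# opt_update_fn(grads, state, params) returns (updates, new_state),
# which is then applied with optax.apply_updates(params, updates).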
reset_frequency = None
def _graft_type_has_diagonal_statistics():
"""Returns True if using diagonal firt order method for grafting."""
return graft_type not in [
- GraftingType.SGD, GraftingType.SQRT_N, GraftingType.NONE
+ GraftingType.SGD,
+ GraftingType.SQRT_N,
+ GraftingType.NONE,
]
def quantized_dtype_for_momentum_buffers(var):
- return jnp.int8 if best_effort_memory_usage_reduction and len(
- var.shape) > 1 else jnp.float32
+ return (
+ jnp.int8
+ if best_effort_memory_usage_reduction and len(var.shape) > 1
+ else jnp.float32
+ )
quantize_second_moment = (
- best_effort_memory_usage_reduction and batch_axis_name)
+ best_effort_memory_usage_reduction and batch_axis_name
+ )
   # Preconditioner and statistics are both stored as int16 in this mode.
# We take out the diagonal to make quantization easier.
@@ -1472,19 +1544,20 @@ def _to_float(maybe_quantized):
def preconditioner_from_params(param):
"""Returns a Preconditioner object for given param."""
return Preconditioner(
- param,
- block_size,
- merge_small_dims_block_size,
- best_effort_shape_interpretation,
- precondtioner_type,
+ param,
+ block_size,
+ merge_small_dims_block_size,
+ best_effort_shape_interpretation,
+ precondtioner_type,
)
def precond_dim(max_size):
"""Derives largest preconditioner dimension."""
return max_size
- def pad_and_maybe_zero_preconditioners(preconditioners, total, max_size,
- step):
+ def pad_and_maybe_zero_preconditioners(
+ preconditioners, total, max_size, step
+ ):
"""Pad preconditioners up to total x max_size x precond_dim(max_size)."""
pd = precond_dim(max_size)
@@ -1513,9 +1586,9 @@ def _pad_preconditioner(preconditioner):
def sharded_init_fn(params):
"""Returns optimizer state (for PJIT mode).
- Args:
- params: the parameters that should be updated.
- """
+ Args:
+ params: the parameters that should be updated.
+ """
params_flat, treedef = jax.tree_util.tree_flatten(params)
# Find max size to pad to.
max_size = 0
@@ -1542,21 +1615,22 @@ def sharded_init_fn(params):
sizes = [s[0] for s in shapes]
shapes = preconditioner.shapes_for_preconditioners()
statistics = [
- matrix_epsilon * jnp.eye(max_size, dtype=jnp.float32)
- for s in shapes
+ matrix_epsilon * jnp.eye(max_size, dtype=jnp.float32) for s in shapes
]
pd = precond_dim(max_size)
# If the preconditioner is using a low-rank representation, initialize
# it to zero instead of an invalid eye.
preconditioners = [
- jnp.eye(max_size, pd, dtype=jnp.float32) * (pd == max_size)
- for s in shapes
+ jnp.eye(max_size, pd, dtype=jnp.float32) * (pd == max_size)
+ for s in shapes
]
padded_statistics.extend(statistics)
padded_preconditioners.extend(preconditioners)
exponent = (
- preconditioner.exponent_for_preconditioner()
- if exponent_override == 0 else exponent_override)
+ preconditioner.exponent_for_preconditioner()
+ if exponent_override == 0
+ else exponent_override
+ )
exponents.extend([exponent] * len(shapes))
diagonal_statistics = jnp.zeros_like(param)
@@ -1564,16 +1638,18 @@ def sharded_init_fn(params):
momentum = jnp.zeros_like(param)
local_stats_flat.append(
- LocalShardedParameterStats(
- diagonal_statistics,
- diagonal_momentum,
- momentum,
- init_training_metrics(
- len(sizes),
- generate_training_metrics,
- ),
- index_start,
- sizes))
+ LocalShardedParameterStats(
+ diagonal_statistics,
+ diagonal_momentum,
+ momentum,
+ init_training_metrics(
+ len(sizes),
+ generate_training_metrics,
+ ),
+ index_start,
+ sizes,
+ )
+ )
local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat)
to_pad = -len(padded_statistics) % num_devices_for_pjit
@@ -1588,22 +1664,27 @@ def sharded_init_fn(params):
# TODO(rohananil): Relax to only the size of the mesh axis where the dim
# is split on.
padded_statistics.extend(
- [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)])
+ [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]
+ )
pd = precond_dim(max_size)
# If the preconditioner is using a low-rank representation, initialize
# it to zero instead of an invalid eye.
- padded_preconditioners.extend([
+ padded_preconditioners.extend(
+ [
jnp.eye(max_size, pd, dtype=stat_dtype) * (pd == max_size)
for _ in range(to_pad)
- ])
+ ]
+ )
exponents.extend([1 for _ in range(to_pad)])
global_stats = GlobalShardedParameterStats(
- jnp.stack(padded_statistics),
- jnp.stack(padded_preconditioners),
- jnp.stack(exponents))
+ jnp.stack(padded_statistics),
+ jnp.stack(padded_preconditioners),
+ jnp.stack(exponents),
+ )
return ShampooState(
- count=jnp.zeros([], jnp.int32),
- stats=ShardedShampooStats(global_stats, local_stats))
+ count=jnp.zeros([], jnp.int32),
+ stats=ShardedShampooStats(global_stats, local_stats),
+ )
def _max_statistics_size_from_params(params):
max_size = 0
@@ -1624,20 +1705,21 @@ def _remove_leading_sharding_annotation(pspec):
else:
return []
- def sharded_init_partition_spec_fn(params,
- params_partition_spec,
- partition_spec_for_statistics):
+ def sharded_init_partition_spec_fn(
+ params, params_partition_spec, partition_spec_for_statistics
+ ):
"""Returns a parallel state tree with PartitionSpec associated with state.
- Args:
- params: A pytree with params.
- params_partition_spec: A pytree with PartitionSpec for params.
- partition_spec_for_statistics: PartitionSpec for the statistics.
- """
+ Args:
+ params: A pytree with params.
+ params_partition_spec: A pytree with PartitionSpec for params.
+ partition_spec_for_statistics: PartitionSpec for the statistics.
+ """
# Parallel lists of spec, and params.
param_pspec_flat, _ = jax.tree_util.tree_flatten(
- params_partition_spec, is_leaf=lambda x: x is None)
+ params_partition_spec, is_leaf=lambda x: x is None
+ )
params_flat, treedef = jax.tree_util.tree_flatten(params)
assert param_pspec_flat
assert params_flat
@@ -1667,48 +1749,57 @@ def sharded_init_partition_spec_fn(params,
m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec)
local_stats_flat.append(
- LocalShardedParameterStats(
- QuantizedValue(
- param_pspec,
- [],
- [],
- jnp.float32,
- False, # pytype: disable=wrong-arg-types # numpy-scalars
- list(param.shape)),
- QuantizedValue(
- m1_pspec,
- [],
- m1_scale_pspec,
- qdtype,
- False, # pytype: disable=wrong-arg-types # numpy-scalars
- list(param.shape)),
- QuantizedValue(
- m2_pspec,
- [],
- m2_scale_pspec,
- qdtype,
- False, # pytype: disable=wrong-arg-types # numpy-scalars
- list(param.shape)),
- init_training_metrics_pspec(generate_training_metrics,),
- index_start,
- sizes))
+ LocalShardedParameterStats(
+ QuantizedValue(
+ param_pspec,
+ [],
+ [],
+ jnp.float32,
+ False, # pytype: disable=wrong-arg-types # numpy-scalars
+ list(param.shape),
+ ),
+ QuantizedValue(
+ m1_pspec,
+ [],
+ m1_scale_pspec,
+ qdtype,
+ False, # pytype: disable=wrong-arg-types # numpy-scalars
+ list(param.shape),
+ ),
+ QuantizedValue(
+ m2_pspec,
+ [],
+ m2_scale_pspec,
+ qdtype,
+ False, # pytype: disable=wrong-arg-types # numpy-scalars
+ list(param.shape),
+ ),
+ init_training_metrics_pspec(
+ generate_training_metrics,
+ ),
+ index_start,
+ sizes,
+ )
+ )
local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat)
- global_stats = GlobalShardedParameterStats(partition_spec_for_statistics,
- partition_spec_for_statistics,
- jax.sharding.PartitionSpec())
+ global_stats = GlobalShardedParameterStats(
+ partition_spec_for_statistics,
+ partition_spec_for_statistics,
+ jax.sharding.PartitionSpec(),
+ )
count_pspec = jax.sharding.PartitionSpec()
return ShampooState( # pytype: disable=wrong-arg-types # numpy-scalars
- count=count_pspec,
- stats=ShardedShampooStats(global_stats, local_stats))
+ count=count_pspec, stats=ShardedShampooStats(global_stats, local_stats)
+ )
def sharded_init_shape_and_dtype_fn(params):
"""Returns a parallel state tree with shape, dtype associated with state.
- Args:
- params: A pytree with params.
- """
+ Args:
+ params: A pytree with params.
+ """
# Parallel lists of spec, and params.
params_flat, treedef = jax.tree_util.tree_flatten(params)
assert params_flat
@@ -1739,31 +1830,39 @@ def sharded_init_shape_and_dtype_fn(params):
diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype]
local_stats_flat.append(
- LocalShardedParameterStats(
- QuantizedValue(
- diagonal_statistics_shape_and_dtype,
- [],
- [], # pytype: disable=wrong-arg-types # numpy-scalars
- jnp.float32,
- False,
- list(param.shape)),
- QuantizedValue(m1_shape_and_dtype, [],
- m1_scale_shape_and_dtype,
- qdtype,
- False,
- list(param.shape)),
- QuantizedValue(m2_shape_and_dtype, [],
- m2_scale_shape_and_dtype,
- qdtype,
- False,
- list(param.shape)),
- init_training_metrics_shapes(
- len(sizes),
- generate_training_metrics,
- ),
- index_start,
- sizes,
- ))
+ LocalShardedParameterStats(
+ QuantizedValue(
+ diagonal_statistics_shape_and_dtype,
+ [],
+ [], # pytype: disable=wrong-arg-types # numpy-scalars
+ jnp.float32,
+ False,
+ list(param.shape),
+ ),
+ QuantizedValue(
+ m1_shape_and_dtype,
+ [],
+ m1_scale_shape_and_dtype,
+ qdtype,
+ False,
+ list(param.shape),
+ ),
+ QuantizedValue(
+ m2_shape_and_dtype,
+ [],
+ m2_scale_shape_and_dtype,
+ qdtype,
+ False,
+ list(param.shape),
+ ),
+ init_training_metrics_shapes(
+ len(sizes),
+ generate_training_metrics,
+ ),
+ index_start,
+ sizes,
+ )
+ )
local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat)
max_statistics_size = _max_statistics_size_from_params(params_flat)
@@ -1773,29 +1872,36 @@ def sharded_init_shape_and_dtype_fn(params):
num_statistics = num_devices_for_pjit
max_statistics_size = block_size
statistics_shape = [
- num_statistics, max_statistics_size, max_statistics_size
+ num_statistics,
+ max_statistics_size,
+ max_statistics_size,
]
preconditioners_shape = [
- num_statistics, max_statistics_size, precond_dim(max_statistics_size)
+ num_statistics,
+ max_statistics_size,
+ precond_dim(max_statistics_size),
]
global_stats = GlobalShardedParameterStats(
- [statistics_shape, jnp.float32], [preconditioners_shape, jnp.float32],
- [[num_statistics], jnp.int32])
+ [statistics_shape, jnp.float32],
+ [preconditioners_shape, jnp.float32],
+ [[num_statistics], jnp.int32],
+ )
return ShampooState( # pytype: disable=wrong-arg-types # numpy-scalars
- count=[[], jnp.float32],
- stats=ShardedShampooStats(global_stats, local_stats))
+ count=[[], jnp.float32],
+ stats=ShardedShampooStats(global_stats, local_stats),
+ )
def sharded_update_fn(grads, state, params):
"""Transform the input gradient and update all statistics in sharded mode.
- Args:
- grads: the gradient tensors for the parameters.
- state: a named tuple containing the state of the optimizer
- params: the parameters that should be updated.
+ Args:
+ grads: the gradient tensors for the parameters.
+ state: a named tuple containing the state of the optimizer
+ params: the parameters that should be updated.
- Returns:
- A tuple containing the new parameters and the new optimizer state.
- """
+ Returns:
+ A tuple containing the new parameters and the new optimizer state.
+ """
params_flat, treedef = jax.tree_util.tree_flatten(params)
grads_flat = treedef.flatten_up_to(grads)
@@ -1803,43 +1909,45 @@ def sharded_update_fn(grads, state, params):
local_stats_flat = treedef.flatten_up_to(state.stats.local_stats)
stats_flat = []
for local_stat in local_stats_flat:
- stats_flat.append(_convert_to_parameter_stats(
+ stats_flat.append(
+ _convert_to_parameter_stats(
global_stats,
local_stat,
- ))
+ )
+ )
new_stats_flat = jax.tree.map(
- lambda g,
- s,
- p: _compute_stats(g, s, p, state.count),
- grads_flat,
- stats_flat,
- params_flat)
+ lambda g, s, p: _compute_stats(g, s, p, state.count),
+ grads_flat,
+ stats_flat,
+ params_flat,
+ )
outputs = jax.tree.map(
- lambda g,
- s,
- p: _transform_grad(g, s, p, state.count),
- grads_flat,
- new_stats_flat,
- params_flat)
+ lambda g, s, p: _transform_grad(g, s, p, state.count),
+ grads_flat,
+ new_stats_flat,
+ params_flat,
+ )
updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
updates = jax.tree_util.tree_unflatten(treedef, updates_flat)
new_local_stats_flat = []
for new_stat, local_stat in zip(new_stats_flat, local_stats_flat):
new_local_stats_flat.append(
- _convert_from_parameter_stats(
- new_stat,
- local_stat,
- ))
+ _convert_from_parameter_stats(
+ new_stat,
+ local_stat,
+ )
+ )
max_size = global_stats.statistics.shape[1]
new_padded_statistics = []
padding_starts = []
for stat in new_stats_flat:
new_padded_statistics.extend(
- [pad_square_matrix(stat, max_size) for stat in stat.statistics])
+ [pad_square_matrix(stat, max_size) for stat in stat.statistics]
+ )
padding_starts.extend([len(stat) for stat in stat.statistics])
# Create global stats
@@ -1857,7 +1965,8 @@ def sharded_update_fn(grads, state, params):
stat_dtype = new_padded_statistics[0].dtype
new_padded_statistics.extend(
- [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)])
+ [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]
+ )
padding_starts += [0] * to_pad
if reuse_preconditioner:
@@ -1865,29 +1974,30 @@ def sharded_update_fn(grads, state, params):
for stat in new_stats_flat:
prev_preconditioners.extend(stat.preconditioners)
prev_padded_preconditioners = pad_and_maybe_zero_preconditioners(
- prev_preconditioners,
- len(new_padded_statistics),
- max_size,
- state.count)
+ prev_preconditioners, len(new_padded_statistics), max_size, state.count
+ )
else:
prev_padded_preconditioners = None
new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
new_stacked_padded_statistics = pjit.with_sharding_constraint(
- new_stacked_padded_statistics, statistics_partition_spec)
+ new_stacked_padded_statistics, statistics_partition_spec
+ )
stacked_padding_starts = jnp.array(padding_starts, jnp.int32)
prev_stacked_padded_preconditioners = _maybe(jnp.stack)(
- prev_padded_preconditioners)
+ prev_padded_preconditioners
+ )
prev_stacked_padded_preconditioners = _maybe(pjit.with_sharding_constraint)(
- prev_padded_preconditioners, statistics_partition_spec)
+ prev_padded_preconditioners, statistics_partition_spec
+ )
def _internal_inverse_pth_root_all():
preconditioners, metrics = _matrix_inverse_pth_root_pjit(
- new_stacked_padded_statistics,
- global_stats.exponents,
- stacked_padding_starts,
- prev_stacked_padded_preconditioners,
- statistics_partition_spec,
+ new_stacked_padded_statistics,
+ global_stats.exponents,
+ stacked_padding_starts,
+ prev_stacked_padded_preconditioners,
+ statistics_partition_spec,
)
return preconditioners, metrics
@@ -1903,39 +2013,47 @@ def _internal_inverse_pth_root_all():
preconditioners_init = new_stacked_padded_statistics[:, :, :pd]
n = new_stacked_padded_statistics.shape[0]
metrics_init = cast(
- TrainingMetrics,
- init_training_metrics(
- n,
- generate_training_metrics=True,
- ))
+ TrainingMetrics,
+ init_training_metrics(
+ n,
+ generate_training_metrics=True,
+ ),
+ )
new_errors = jnp.ones_like(metrics_init.inverse_pth_root_errors) * (
- inverse_failure_threshold)
+ inverse_failure_threshold
+ )
metrics_init = metrics_init.replace(inverse_pth_root_errors=new_errors)
init_state = [preconditioners_init, metrics_init]
new_preconditioners, metrics = efficient_cond(
- perform_step, _internal_inverse_pth_root_all, init_state)
+ perform_step, _internal_inverse_pth_root_all, init_state
+ )
if generate_training_metrics:
new_local_stats_flat = _add_metrics_into_local_stats(
- new_local_stats_flat, metrics, ~perform_step)
- new_local_stats = jax.tree_util.tree_unflatten(treedef,
- new_local_stats_flat)
+ new_local_stats_flat, metrics, ~perform_step
+ )
+ new_local_stats = jax.tree_util.tree_unflatten(
+ treedef, new_local_stats_flat
+ )
errors = metrics.inverse_pth_root_errors
errors = errors.reshape((-1, 1, 1))
predicate = jnp.logical_or(
- jnp.isnan(errors),
- errors >= inverse_failure_threshold).astype(new_preconditioners.dtype)
+ jnp.isnan(errors), errors >= inverse_failure_threshold
+ ).astype(new_preconditioners.dtype)
# TODO(rohananil): Check for numerical instabilities.
new_conditional_preconditioners = (
- predicate * global_stats.preconditioners +
- (1.0 - predicate) * new_preconditioners)
+ predicate * global_stats.preconditioners
+ + (1.0 - predicate) * new_preconditioners
+ )
new_global_stats = GlobalShardedParameterStats(
- new_stacked_padded_statistics,
- new_conditional_preconditioners,
- global_stats.exponents)
+ new_stacked_padded_statistics,
+ new_conditional_preconditioners,
+ global_stats.exponents,
+ )
new_shampoo_state = ShampooState(
- count=state.count + 1,
- stats=ShardedShampooStats(new_global_stats, new_local_stats))
+ count=state.count + 1,
+ stats=ShardedShampooStats(new_global_stats, new_local_stats),
+ )
return updates, new_shampoo_state
def init_fn(params):
@@ -1948,13 +2066,13 @@ def _init(param):
if not _skip_preconditioning(param):
shapes = preconditioner.shapes_for_preconditioners()
statistics = [
- matrix_epsilon * jnp.eye(s[0], dtype=jnp.float32) for s in shapes
+ matrix_epsilon * jnp.eye(s[0], dtype=jnp.float32) for s in shapes
]
# If the preconditioner is using a low-rank representation, initialize
# it to zero instead of an invalid eye.
preconditioners = [
- jnp.eye(s[0], s[1], dtype=jnp.float32) * (s[0] == s[1])
- for s in shapes
+ jnp.eye(s[0], s[1], dtype=jnp.float32) * (s[0] == s[1])
+ for s in shapes
]
diagonal_statistics = []
@@ -1967,25 +2085,28 @@ def _init(param):
momentum = jnp.zeros_like(param)
return ParameterStats(
- diagonal_statistics,
- statistics,
- preconditioners,
- # _quantize_diagonal_statistics(diagonal_statistics),
- # _maybe_quantize_statistics(statistics),
- # _maybe_quantize_preconditioners(preconditioners),
- diagonal_momentum,
- momentum,
- init_training_metrics(
- len(statistics),
- generate_training_metrics,
- ))
+ diagonal_statistics,
+ statistics,
+ preconditioners,
+ # _quantize_diagonal_statistics(diagonal_statistics),
+ # _maybe_quantize_statistics(statistics),
+ # _maybe_quantize_preconditioners(preconditioners),
+ diagonal_momentum,
+ momentum,
+ init_training_metrics(
+ len(statistics),
+ generate_training_metrics,
+ ),
+ )
return ShampooState(
- count=jnp.zeros([], jnp.int32), stats=jax.tree.map(_init, params))
+ count=jnp.zeros([], jnp.int32), stats=jax.tree.map(_init, params)
+ )
def _skip_preconditioning(param):
return len(param.shape) < skip_preconditioning_rank_lt or any(
- s > skip_preconditioning_dim_size_gt for s in param.shape)
+ s > skip_preconditioning_dim_size_gt for s in param.shape
+ )
def _compute_stats(grad, state, param, step):
"""Compute per-parameter statistics."""
@@ -1997,105 +2118,117 @@ def _compute_stats(grad, state, param, step):
def compute_updated_statistics():
return preconditioner.updated_statistics_from_grad(
- state.statistics,
- grad,
- w1=w1,
- w2=w2,
- to_float=_to_float,
- from_float=lambda x: x,
- # from_float=lambda x: _maybe_quantize_statistics([x])[0],
- precision=tensordot_precision,
+ state.statistics,
+ grad,
+ w1=w1,
+ w2=w2,
+ to_float=_to_float,
+ from_float=lambda x: x,
+ # from_float=lambda x: _maybe_quantize_statistics([x])[0],
+ precision=tensordot_precision,
)
if statistics_compute_steps > 1:
perform_step = step % statistics_compute_steps == 0
init_state = state.statistics
new_statistics = list(
- efficient_cond(perform_step, compute_updated_statistics,
- init_state))
+ efficient_cond(perform_step, compute_updated_statistics, init_state)
+ )
else:
new_statistics = compute_updated_statistics()
- return ParameterStats(state.diagonal_statistics,
- new_statistics,
- state.preconditioners,
- state.diagonal_momentum,
- state.momentum,
- state.training_metrics)
+ return ParameterStats(
+ state.diagonal_statistics,
+ new_statistics,
+ state.preconditioners,
+ state.diagonal_momentum,
+ state.momentum,
+ state.training_metrics,
+ )
mi_pth_root = functools.partial(
- matrix_inverse_pth_root,
- ridge_epsilon=matrix_epsilon,
- precision=precision,
- relative_matrix_epsilon=relative_matrix_epsilon,
- lobpcg_topk_precondition=lobpcg_topk_precondition,
- lobpcg_max_iter=lobpcg_max_iter,
- eigh=eigh)
+ matrix_inverse_pth_root,
+ ridge_epsilon=matrix_epsilon,
+ precision=precision,
+ relative_matrix_epsilon=relative_matrix_epsilon,
+ lobpcg_topk_precondition=lobpcg_topk_precondition,
+ lobpcg_max_iter=lobpcg_max_iter,
+ eigh=eigh,
+ )
def _matrix_inverse_pth_root_vmap(xs, ps, padding_starts, prev):
return jax.vmap(mi_pth_root)(
- xs, ps, padding_start=padding_starts, prev=prev)
+ xs, ps, padding_start=padding_starts, prev=prev
+ )
- def _matrix_inverse_pth_root_pjit(xs,
- ps,
- padding_starts,
- prev_preconds=None,
- statistics_partition_spec=None):
+ def _matrix_inverse_pth_root_pjit(
+ xs, ps, padding_starts, prev_preconds=None, statistics_partition_spec=None
+ ):
# Partition the concatenated statistics matrix across all cores.
pspec_for_partition = preconditioner_partition_spec
partitioned_xs = pjit.with_sharding_constraint(xs, pspec_for_partition)
if preconditioner_partition_spec:
partitioned_ps_spec = jax.sharding.PartitionSpec(
- preconditioner_partition_spec[0])
+ preconditioner_partition_spec[0]
+ )
else:
partitioned_ps_spec = None
partitioned_ps = pjit.with_sharding_constraint(ps, partitioned_ps_spec)
partitioned_prev_preconds = _maybe(pjit.with_sharding_constraint)(
- prev_preconds, preconditioner_partition_spec)
+ prev_preconds, preconditioner_partition_spec
+ )
partitioned_padding_starts = pjit.with_sharding_constraint(
- padding_starts, partitioned_ps_spec) # paddings are scalars like ps.
+ padding_starts, partitioned_ps_spec
+ ) # paddings are scalars like ps.
# Run matrix inverse pth root on each shard.
partitioned_preconditioners, partitioned_metrics = (
- _matrix_inverse_pth_root_vmap(
- partitioned_xs,
- partitioned_ps,
- partitioned_padding_starts,
- prev=partitioned_prev_preconds))
+ _matrix_inverse_pth_root_vmap(
+ partitioned_xs,
+ partitioned_ps,
+ partitioned_padding_starts,
+ prev=partitioned_prev_preconds,
+ )
+ )
# Reshard output to have the same PSpec as input. This is required to avoid
# vmap seeing the full set of statistics.
partitioned_preconditioners = pjit.with_sharding_constraint(
- partitioned_preconditioners, pspec_for_partition)
+ partitioned_preconditioners, pspec_for_partition
+ )
# Recombine the outputs at each core.
- preconditioners = pjit.with_sharding_constraint(partitioned_preconditioners,
- statistics_partition_spec)
- metrics = pjit.with_sharding_constraint(partitioned_metrics,
- jax.sharding.PartitionSpec())
+ preconditioners = pjit.with_sharding_constraint(
+ partitioned_preconditioners, statistics_partition_spec
+ )
+ metrics = pjit.with_sharding_constraint(
+ partitioned_metrics, jax.sharding.PartitionSpec()
+ )
return preconditioners, metrics
- def _pmap_compute_preconditioners(states,
- step,
- statistics,
- num_statistics_per_state,
- original_shapes,
- exponents,
- max_size,
- prev_preconditioners):
+ def _pmap_compute_preconditioners(
+ states,
+ step,
+ statistics,
+ num_statistics_per_state,
+ original_shapes,
+ exponents,
+ max_size,
+ prev_preconditioners,
+ ):
"""Computes preconditioners for given statistics in states in PMAP mode.
- Args:
- states: A list of optimizer states.
- step: Current step number
- statistics: A list of statistics for all variables (for every dim)
- num_statistics_per_state: Number of statistis per state to reconstruct
- output states.
- original_shapes: A list of shapes of the statistics.
- exponents: Exponent power to use for inverse-pth roots.
- max_size: Maximum dim of the statistics to pad.
- prev_preconditioners: Previously available preconditioner.
-
- Returns:
- New optimizer states after computing the preconditioner.
- """
+ Args:
+ states: A list of optimizer states.
+ step: Current step number
+ statistics: A list of statistics for all variables (for every dim)
+      num_statistics_per_state: Number of statistics per state to reconstruct
+ output states.
+ original_shapes: A list of shapes of the statistics.
+ exponents: Exponent power to use for inverse-pth roots.
+ max_size: Maximum dim of the statistics to pad.
+      prev_preconditioners: Previously available preconditioners.
+
+ Returns:
+ New optimizer states after computing the preconditioner.
+ """
if batch_axis_name:
num_devices = lax.psum(1, batch_axis_name)
else:
@@ -2103,13 +2236,15 @@ def _pmap_compute_preconditioners(states,
num_statistics = len(statistics)
# Pad statistics and exponents to next multiple of num_devices.
packed_statistics = [
- pad_square_matrix(stat, max_size) for stat in statistics
+ pad_square_matrix(stat, max_size) for stat in statistics
]
to_pad = -num_statistics % num_devices
- packed_statistics.extend([
+ packed_statistics.extend(
+ [
jnp.eye(max_size, dtype=packed_statistics[0].dtype)
for _ in range(to_pad)
- ])
+ ]
+ )
exponents.extend([1 for _ in range(to_pad)])
paddings = [len(stat) for stat in statistics] + [0] * to_pad
@@ -2119,7 +2254,8 @@ def _pmap_compute_preconditioners(states,
if reuse_preconditioner:
assert len(prev_preconditioners) == num_statistics
packed_preconditioners = pad_and_maybe_zero_preconditioners(
- prev_preconditioners, len(packed_statistics), max_size, step)
+ prev_preconditioners, len(packed_statistics), max_size, step
+ )
else:
packed_preconditioners = None
@@ -2132,10 +2268,10 @@ def _internal_inverse_pth_root_all():
if batch_axis_name:
current_replica = lax.axis_index(batch_axis_name)
preconditioners, metrics = _matrix_inverse_pth_root_vmap(
- all_statistics[current_replica],
- all_exponents[current_replica],
- all_paddings[current_replica],
- _maybe_ix(all_preconditioners, current_replica),
+ all_statistics[current_replica],
+ all_exponents[current_replica],
+ all_paddings[current_replica],
+ _maybe_ix(all_preconditioners, current_replica),
)
preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
metrics = jax.lax.all_gather(metrics, batch_axis_name)
@@ -2143,14 +2279,15 @@ def _internal_inverse_pth_root_all():
metrics_flat = jax.tree.map(unbatch, metrics)
else:
preconditioners, metrics = _matrix_inverse_pth_root_vmap(
- all_statistics[0],
- all_exponents[0],
- all_paddings[0],
- _maybe_ix(all_preconditioners, 0),
+ all_statistics[0],
+ all_exponents[0],
+ all_paddings[0],
+ _maybe_ix(all_preconditioners, 0),
)
preconditioners_flat = unbatch(jnp.stack([preconditioners]))
metrics = jax.tree.map(
- functools.partial(jnp.expand_dims, axis=0), metrics)
+ functools.partial(jnp.expand_dims, axis=0), metrics
+ )
metrics_flat = jax.tree.map(unbatch, metrics)
return preconditioners_flat, metrics_flat
@@ -2163,40 +2300,57 @@ def _internal_inverse_pth_root_all():
# shaped tensors. Note statistics will be ignored as we are passing in
# a large error value.
preconditioners_init = [
- s[:, :precond_dim(s.shape[0])] for s in packed_statistics
+ s[:, : precond_dim(s.shape[0])] for s in packed_statistics
]
n = len(packed_statistics)
metrics_init = jax.tree.map(
- lambda x: [x] * n,
- default_training_metrics().replace(
- inverse_pth_root_errors=inverse_failure_threshold))
+ lambda x: [x] * n,
+ default_training_metrics().replace(
+ inverse_pth_root_errors=inverse_failure_threshold
+ ),
+ )
init_state = [preconditioners_init, metrics_init]
preconditioners_flat, metrics_flat = efficient_cond(
- perform_step, _internal_inverse_pth_root_all, init_state)
+ perform_step, _internal_inverse_pth_root_all, init_state
+ )
def _skip(error):
condition = jnp.logical_or(
- jnp.isnan(error), error >= inverse_failure_threshold)
+ jnp.isnan(error), error >= inverse_failure_threshold
+ )
return condition.astype(error.dtype)
def _select_preconditioner(error, new_p, old_p):
return lax.cond(
- _skip(error), lambda _: old_p, lambda _: new_p, operand=None)
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None
+ )
new_preconditioners_flat = []
new_errors_flat = metrics_flat.inverse_pth_root_errors
- for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
- prev_preconditioners, new_errors_flat):
+ for p, shape, prev_p, error in zip(
+ preconditioners_flat,
+ original_shapes,
+ prev_preconditioners,
+ new_errors_flat,
+ ):
new_preconditioners_flat.append(
- _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))
+ _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
+ )
- assert len(states) == (len(num_statistics_per_state),
- f"{len(states)} vs {len(num_statistics_per_state)}")
+    assert len(states) == len(num_statistics_per_state), (
+      f'{len(states)} vs {len(num_statistics_per_state)}'
+    )
assert len(new_preconditioners_flat) == num_statistics
assert len(new_errors_flat) == len(packed_statistics), (
- len(new_errors_flat), len(packed_statistics))
+ len(new_errors_flat),
+ len(packed_statistics),
+ )
assert len(new_errors_flat) == num_statistics + to_pad, (
- len(new_errors_flat), num_statistics, to_pad)
+ len(new_errors_flat),
+ num_statistics,
+ to_pad,
+ )
     # Add back empty preconditioners so that we can set the optimizer state.
preconditioners_for_states = []
@@ -2206,26 +2360,31 @@ def _select_preconditioner(error, new_p, old_p):
if num_statistics == 0:
preconditioners_for_states.append([])
metrics_for_states.append(
- init_training_metrics(0, generate_training_metrics))
+ init_training_metrics(0, generate_training_metrics)
+ )
else:
- preconditioners_for_state = new_preconditioners_flat[idx:idx +
- num_statistics]
+ preconditioners_for_state = new_preconditioners_flat[
+ idx : idx + num_statistics
+ ]
assert len(state.statistics) == len(preconditioners_for_state)
preconditioners_for_states.append(preconditioners_for_state)
if generate_training_metrics:
# pylint:disable=cell-var-from-loop Used immediately.
metrics_for_state = jax.tree.map(
- lambda x: jnp.stack(x[idx:idx + num_statistics]),
- metrics_flat,
- is_leaf=lambda x: isinstance(x, list))
+ lambda x: jnp.stack(x[idx : idx + num_statistics]),
+ metrics_flat,
+ is_leaf=lambda x: isinstance(x, list),
+ )
assert jax.tree_util.tree_all(
- jax.tree.map(lambda x: len(state.statistics) == len(x),
- metrics_for_state))
+ jax.tree.map(
+ lambda x: len(state.statistics) == len(x), metrics_for_state
+ )
+ )
# If we skipped preconditioner computation, record old metrics.
- metrics_for_state = efficient_cond(perform_step,
- lambda: [metrics_for_state],
- [state.training_metrics])[0]
+ metrics_for_state = efficient_cond(
+ perform_step, lambda: [metrics_for_state], [state.training_metrics]
+ )[0]
# pylint:enable=cell-var-from-loop
else:
metrics_for_state = optax.MaskedNode()
@@ -2234,32 +2393,36 @@ def _select_preconditioner(error, new_p, old_p):
idx += num_statistics
new_states = []
for state, new_preconditioners, new_metrics in zip(
- states, preconditioners_for_states, metrics_for_states):
+ states, preconditioners_for_states, metrics_for_states
+ ):
# Note the preconditioner may have been skipped, but we still update the
# metrics with the new error values; whether the preconditioner that's
# actively being used is stale can be derived from the new_metrics
# being greater than the failure threshold.
new_states.append(
- ParameterStats(state.diagonal_statistics,
- state.statistics,
- new_preconditioners,
- state.diagonal_momentum,
- state.momentum,
- new_metrics))
+ ParameterStats(
+ state.diagonal_statistics,
+ state.statistics,
+ new_preconditioners,
+ state.diagonal_momentum,
+ state.momentum,
+ new_metrics,
+ )
+ )
return new_states
def _compute_preconditioners(states, params, step):
"""Computes preconditioners for given statistics in states.
- Args:
- states: A list of optimizer states.
- params: A list of params.
- step: Current step number
+ Args:
+ states: A list of optimizer states.
+ params: A list of params.
+ step: Current step number
- Returns:
- New optimizer states after computing the preconditioner.
- """
+ Returns:
+ New optimizer states after computing the preconditioner.
+ """
statistics = []
num_statistics_per_state = []
original_shapes = []
@@ -2274,8 +2437,11 @@ def _compute_preconditioners(states, params, step):
if num_statistics > 0:
preconditioner = preconditioner_from_params(param)
for statistic in state.statistics:
- exponents.append(preconditioner.exponent_for_preconditioner(
- ) if exponent_override == 0 else exponent_override)
+ exponents.append(
+ preconditioner.exponent_for_preconditioner()
+ if exponent_override == 0
+ else exponent_override
+ )
original_shapes_for_state.append(statistic.shape)
max_size = max(max_size, statistic.shape[0])
@@ -2283,14 +2449,16 @@ def _compute_preconditioners(states, params, step):
prev_preconditioners.extend(state.preconditioners)
original_shapes.extend(original_shapes_for_state)
- return _pmap_compute_preconditioners(states,
- step,
- statistics,
- num_statistics_per_state,
- original_shapes,
- exponents,
- max_size,
- prev_preconditioners)
+ return _pmap_compute_preconditioners(
+ states,
+ step,
+ statistics,
+ num_statistics_per_state,
+ original_shapes,
+ exponents,
+ max_size,
+ prev_preconditioners,
+ )
def _transform_grad(grad, state, param, step):
"""Transform per-parameter gradients."""
@@ -2298,21 +2466,25 @@ def _transform_grad(grad, state, param, step):
sgd_update = grad
new_diagonal_statistics = state.diagonal_statistics
- if (graft_type == GraftingType.ADAGRAD or
- graft_type == GraftingType.ADAGRAD_NORMALIZED):
-
+ if (
+ graft_type == GraftingType.ADAGRAD
+ or graft_type == GraftingType.ADAGRAD_NORMALIZED
+ ):
scaled_grad = grad
if graft_type == GraftingType.ADAGRAD_NORMALIZED:
scaled_grad = grad / (jnp.linalg.norm(grad) + _EPSILON)
new_diagonal_statistics = (
- state.diagonal_statistics.to_float() + jnp.square(scaled_grad))
+ state.diagonal_statistics.to_float() + jnp.square(scaled_grad)
+ )
adagrad_update = scaled_grad / (
- jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon
+ )
grafting_update = adagrad_update
- elif (graft_type == GraftingType.RMSPROP or
- graft_type == GraftingType.RMSPROP_NORMALIZED):
-
+ elif (
+ graft_type == GraftingType.RMSPROP
+ or graft_type == GraftingType.RMSPROP_NORMALIZED
+ ):
scaled_grad = grad
if graft_type == GraftingType.RMSPROP_NORMALIZED:
scaled_grad = grad / (jnp.linalg.norm(grad) + _EPSILON)
@@ -2321,15 +2493,19 @@ def _transform_grad(grad, state, param, step):
w2 = jnp.where(beta2 == 1.0, beta2, 1.0 - beta2)
new_diagonal_statistics = (
- w1 * state.diagonal_statistics + w2 * jnp.square(scaled_grad))
+ w1 * state.diagonal_statistics + w2 * jnp.square(scaled_grad)
+ )
rmsprop_update = scaled_grad / (
- jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon
+ )
if clip_by_scaled_gradient_norm:
scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / (
- jnp.sqrt(float(rmsprop_update.size)))
+ jnp.sqrt(float(rmsprop_update.size))
+ )
clipping_denom = jnp.maximum(
- 1., scaled_grad_norm / clip_by_scaled_gradient_norm)
+ 1.0, scaled_grad_norm / clip_by_scaled_gradient_norm
+ )
rmsprop_update /= clipping_denom
grafting_update = rmsprop_update
@@ -2349,12 +2525,14 @@ def _transform_grad(grad, state, param, step):
precond_grad = grad
if not _skip_preconditioning(param):
- precond_grad = preconditioner.preconditioned_grad(precond_grad,
- state.preconditioners)
+ precond_grad = preconditioner.preconditioned_grad(
+ precond_grad, state.preconditioners
+ )
else:
if graft_type == GraftingType.NONE:
- logging.error("skipping preconditioning without grafting for param %s",
- param)
+ logging.error(
+ 'skipping preconditioning without grafting for param %s', param
+ )
precond_grad = grafting_update
grafting_update_norm = jnp.linalg.norm(grafting_update)
@@ -2369,39 +2547,49 @@ def _transform_grad(grad, state, param, step):
shampoo_update_with_wd = shampoo_update
grafting_update_with_wd = grafting_update
- if (weight_decay != 0 and weight_decay is not None and
- not decoupled_weight_decay):
+ if (
+ weight_decay != 0
+ and weight_decay is not None
+ and not decoupled_weight_decay
+ ):
shampoo_update_with_wd = shampoo_update + weight_decay * param
grafting_update_with_wd = grafting_update + weight_decay * param
w = (1.0 - beta1) if moving_average_for_momentum else 1.0
shampoo_update_with_wd_momentum = (
- state.momentum * beta1 + w * shampoo_update_with_wd)
+ state.momentum * beta1 + w * shampoo_update_with_wd
+ )
grafting_update_with_wd_momentum = (
- state.diagonal_momentum * beta1 + w * grafting_update_with_wd)
+ state.diagonal_momentum * beta1 + w * grafting_update_with_wd
+ )
run_shampoo = (step >= start_preconditioning_step).astype(
- grafting_update_with_wd_momentum.dtype)
+ grafting_update_with_wd_momentum.dtype
+ )
momentum_update = (
- run_shampoo * shampoo_update_with_wd_momentum +
- (1.0 - run_shampoo) * grafting_update_with_wd_momentum)
+ run_shampoo * shampoo_update_with_wd_momentum
+ + (1.0 - run_shampoo) * grafting_update_with_wd_momentum
+ )
wd_update = (
- run_shampoo * shampoo_update_with_wd +
- (1.0 - run_shampoo) * grafting_update_with_wd)
+ run_shampoo * shampoo_update_with_wd
+ + (1.0 - run_shampoo) * grafting_update_with_wd
+ )
nesterov_momentum_update = momentum_update
if nesterov:
nesterov_momentum_update = w * wd_update + beta1 * momentum_update
- if (weight_decay != 0 and weight_decay is not None and
- decoupled_weight_decay):
+ if (
+ weight_decay != 0 and weight_decay is not None and decoupled_weight_decay
+ ):
nesterov_momentum_update = (
- nesterov_momentum_update + lr * weight_decay * param)
+ nesterov_momentum_update + lr * weight_decay * param
+ )
momentum_multiplier = lr if decoupled_learning_rate else 1.0
transformed_update = -1.0 * momentum_multiplier * nesterov_momentum_update
@@ -2409,26 +2597,28 @@ def _transform_grad(grad, state, param, step):
new_diagonal_momentum = grafting_update_with_wd_momentum
new_momentum = shampoo_update_with_wd_momentum
- param_stats = ParameterStats(new_diagonal_statistics,
- state.statistics,
- state.preconditioners,
- new_diagonal_momentum,
- new_momentum,
- state.training_metrics)
+ param_stats = ParameterStats(
+ new_diagonal_statistics,
+ state.statistics,
+ state.preconditioners,
+ new_diagonal_momentum,
+ new_momentum,
+ state.training_metrics,
+ )
return transformed_update, param_stats
def update_fn(grads, state, params):
"""Transform the input gradient and update all statistics.
- Args:
- grads: the gradient tensors for the parameters and any custom
- gradients for preconditioners.
- state: a named tuple containing the state of the optimizer
- params: the parameters that should be updated.
+ Args:
+ grads: the gradient tensors for the parameters and any custom
+ gradients for preconditioners.
+ state: a named tuple containing the state of the optimizer
+ params: the parameters that should be updated.
- Returns:
- A tuple containing the new parameters and the new optimizer state.
- """
+ Returns:
+ A tuple containing the new parameters and the new optimizer state.
+ """
grads_custom = None
if custom_preconditioner and isinstance(grads, tuple):
grads, grads_custom = grads
@@ -2442,23 +2632,21 @@ def update_fn(grads, state, params):
stats_grads = treedef.flatten_up_to(grads_custom)
new_stats_flat = jax.tree.map(
- lambda g,
- s,
- p: _compute_stats(g, s, p, state.count),
- stats_grads,
- stats_flat,
- params_flat)
-
- new_stats_flat = _compute_preconditioners(new_stats_flat,
- params_flat,
- state.count)
+ lambda g, s, p: _compute_stats(g, s, p, state.count),
+ stats_grads,
+ stats_flat,
+ params_flat,
+ )
+
+ new_stats_flat = _compute_preconditioners(
+ new_stats_flat, params_flat, state.count
+ )
outputs = jax.tree.map(
- lambda g,
- s,
- p: _transform_grad(g, s, p, state.count),
- grads_flat,
- new_stats_flat,
- params_flat)
+ lambda g, s, p: _transform_grad(g, s, p, state.count),
+ grads_flat,
+ new_stats_flat,
+ params_flat,
+ )
updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
updates = jax.tree_util.tree_unflatten(treedef, updates_flat)
new_stats = jax.tree_util.tree_unflatten(treedef, new_stats_flat)
@@ -2472,9 +2660,10 @@ def update_fn(grads, state, params):
def _init_fns(unused_params):
return InitFnState(
- init_fn=opt_init_fn,
- pspec_fn=sharded_init_partition_spec_fn,
- shape_and_dtype_fn=sharded_init_shape_and_dtype_fn)
+ init_fn=opt_init_fn,
+ pspec_fn=sharded_init_partition_spec_fn,
+ shape_and_dtype_fn=sharded_init_shape_and_dtype_fn,
+ )
opt_update_fn = sharded_update_fn
return optax.GradientTransformation(_init_fns, opt_update_fn)
diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py
index 2cd054062..8bf4d2dc5 100644
--- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py
+++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py
@@ -3,24 +3,27 @@
import functools
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
-from reference_algorithms.paper_baselines.shampoo.jax.distributed_shampoo import \
- distributed_shampoo
+from reference_algorithms.paper_baselines.shampoo.jax.distributed_shampoo import (
+ distributed_shampoo,
+)
_GRAD_CLIP_EPS = 1e-6
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Shampoo optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -30,102 +33,116 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+ init_value=hyperparameters.learning_rate, decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters)
opt_init_fn, opt_update_fn = distributed_shampoo(
- learning_rate=lr_schedule_fn,
- beta1=1.0 - hyperparameters.one_minus_beta1,
- beta2=hyperparameters.beta2,
- weight_decay=hyperparameters.weight_decay,
- batch_axis_name='batch',
- eigh=False)
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ learning_rate=lr_schedule_fn,
+ beta1=1.0 - hyperparameters.one_minus_beta1,
+ beta2=hyperparameters.beta2,
+ weight_decay=hyperparameters.weight_decay,
+ batch_axis_name='batch',
+ eigh=False,
+ )
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -142,37 +159,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -208,14 +231,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json b/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json
index 9d804ba0e..58f6f4fd1 100644
--- a/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json
+++ b/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json
@@ -1,23 +1,29 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "min": 1e-2, "max": 0.15, "scaling": "log"
- },
- "beta2": {
- "feasible_points": [0.999]
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 5e-3, "max": 1.0, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "min": 1e-2,
+ "max": 0.15,
+ "scaling": "log"
+ },
+ "beta2": {
+ "feasible_points": [0.999]
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 5e-3,
+ "max": 1.0,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ }
}
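To make the search-space format above concrete, here is a hedged, illustrative sketch (not the benchmark's own tuning code) of how one trial could be drawn from such a JSON file: entries with "min"/"max" and "scaling": "log" are sampled log-uniformly, while "feasible_points" entries are picked uniformly at random. The function name and sampling policy are assumptions for illustration only.

import json
import math
import random


def sample_trial(search_space_path, seed=0):
  """Illustrative sampler for a tuning_search_space.json file (assumption)."""
  rng = random.Random(seed)
  with open(search_space_path) as f:
    space = json.load(f)
  trial = {}
  for name, spec in space.items():
    if 'feasible_points' in spec:
      # Discrete hyperparameters: pick one of the listed values.
      trial[name] = rng.choice(spec['feasible_points'])
    elif spec.get('scaling') == 'log':
      # Log-scaled ranges: sample uniformly in log space.
      trial[name] = math.exp(
        rng.uniform(math.log(spec['min']), math.log(spec['max']))
      )
    else:
      # Fallback: plain uniform sampling over [min, max].
      trial[name] = rng.uniform(spec['min'], spec['max'])
  return trial


# e.g. sample_trial(
#   'reference_algorithms/paper_baselines/shampoo/tuning_search_space.json')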
diff --git a/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json
index b8bd2ea49..5a7c27be7 100644
--- a/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json
+++ b/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json
@@ -1,23 +1,27 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "feasible_points": [0.1]
- },
- "beta2": {
- "feasible_points": [0.999]
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 5e-3, "max": 1.0, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "feasible_points": [0.1]
+ },
+ "beta2": {
+ "feasible_points": [0.999]
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 5e-3,
+ "max": 1.0,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/cosine_warmup.py b/reference_algorithms/target_setting_algorithms/cosine_warmup.py
index 116ebc555..6a2241732 100644
--- a/reference_algorithms/target_setting_algorithms/cosine_warmup.py
+++ b/reference_algorithms/target_setting_algorithms/cosine_warmup.py
@@ -1,35 +1,37 @@
"""Implementions of a linear warmup then cosine decay LR schedule."""
import optax
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=hyperparameters.warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=hyperparameters.warmup_steps,
+ )
cosine_steps = max(step_hint - hyperparameters.warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+ init_value=hyperparameters.learning_rate, decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn],
- boundaries=[hyperparameters.warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[hyperparameters.warmup_steps]
+ )
return schedule_fn
def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup = LinearLR(
- optimizer,
- start_factor=1e-10,
- end_factor=1.,
- total_iters=hyperparameters.warmup_steps)
+ optimizer,
+ start_factor=1e-10,
+ end_factor=1.0,
+ total_iters=hyperparameters.warmup_steps,
+ )
cosine_steps = max(step_hint - hyperparameters.warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
- optimizer,
- schedulers=[warmup, cosine_decay],
- milestones=[hyperparameters.warmup_steps])
+ optimizer,
+ schedulers=[warmup, cosine_decay],
+ milestones=[hyperparameters.warmup_steps],
+ )
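A short usage sketch of the JAX schedule defined above. The hyperparameters object is a stand-in namespace with just the two fields the schedule reads, and the values are placeholders; the resulting schedule_fn is an ordinary optax schedule, i.e. a callable from step count to learning rate:

from types import SimpleNamespace

# Placeholder hyperparameters; real runs take these from a tuning search space.
hparams = SimpleNamespace(learning_rate=1e-3, warmup_steps=100)
schedule_fn = jax_cosine_warmup(step_hint=1000, hyperparameters=hparams)

for step in (0, 50, 100, 550, 1000):
  # Linear warmup from 0 to 1e-3 over the first 100 steps,
  # then cosine decay towards 0 over the remaining 900 steps.
  print(step, float(schedule_fn(step)))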
diff --git a/reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json
index cab6fd5f7..6061940a9 100644
--- a/reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json
@@ -1,27 +1,17 @@
{
"learning_rate": {
- "feasible_points": [
- 0.0033313215673016375
- ]
+ "feasible_points": [0.0033313215673016375]
},
"beta1": {
- "feasible_points": [
- 0.948000082541717
- ]
+ "feasible_points": [0.948000082541717]
},
"beta2": {
- "feasible_points": [
- 0.9987934318891598
- ]
+ "feasible_points": [0.9987934318891598]
},
"warmup_steps": {
- "feasible_points": [
- 159
- ]
+ "feasible_points": [159]
},
"weight_decay": {
- "feasible_points": [
- 0.0035784380304876183
- ]
+ "feasible_points": [0.0035784380304876183]
}
}
diff --git a/reference_algorithms/target_setting_algorithms/criteo1tb_embed_init/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/criteo1tb_embed_init/tuning_search_space.json
index bd6c9702f..110138607 100644
--- a/reference_algorithms/target_setting_algorithms/criteo1tb_embed_init/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/criteo1tb_embed_init/tuning_search_space.json
@@ -1,28 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.002517072211464665
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9908351643533544
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9859568907533993
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 799
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.12274552870237089
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.002517072211464665]
+ },
+ "beta1": {
+ "feasible_points": [0.9908351643533544]
+ },
+ "beta2": {
+ "feasible_points": [0.9859568907533993]
+ },
+ "warmup_steps": {
+ "feasible_points": [799]
+ },
+ "weight_decay": {
+ "feasible_points": [0.12274552870237089]
}
-
\ No newline at end of file
+}
diff --git a/reference_algorithms/target_setting_algorithms/criteo1tb_layernorm/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/criteo1tb_layernorm/tuning_search_space.json
index 8d128dae1..a7f52681d 100644
--- a/reference_algorithms/target_setting_algorithms/criteo1tb_layernorm/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/criteo1tb_layernorm/tuning_search_space.json
@@ -1,28 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.05493199486120455
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.954922991734919
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9986188074995163
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 799
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.00011065469792077193
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.05493199486120455]
+ },
+ "beta1": {
+ "feasible_points": [0.954922991734919]
+ },
+ "beta2": {
+ "feasible_points": [0.9986188074995163]
+ },
+ "warmup_steps": {
+ "feasible_points": [799]
+ },
+ "weight_decay": {
+ "feasible_points": [0.00011065469792077193]
}
-
\ No newline at end of file
+}
diff --git a/reference_algorithms/target_setting_algorithms/criteo1tb_resnet/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/criteo1tb_resnet/tuning_search_space.json
index a33ae2ff5..31ce92bd1 100644
--- a/reference_algorithms/target_setting_algorithms/criteo1tb_resnet/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/criteo1tb_resnet/tuning_search_space.json
@@ -1,28 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.001493629901423942
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9592129978682067
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9824918272399145
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 399
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.00038587516415285595
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.001493629901423942]
+ },
+ "beta1": {
+ "feasible_points": [0.9592129978682067]
+ },
+ "beta2": {
+ "feasible_points": [0.9824918272399145]
+ },
+ "warmup_steps": {
+ "feasible_points": [399]
+ },
+ "weight_decay": {
+ "feasible_points": [0.00038587516415285595]
}
-
\ No newline at end of file
+}
diff --git a/reference_algorithms/target_setting_algorithms/data_selection.py b/reference_algorithms/target_setting_algorithms/data_selection.py
index 5e70f9f8b..e0d9c0ee9 100644
--- a/reference_algorithms/target_setting_algorithms/data_selection.py
+++ b/reference_algorithms/target_setting_algorithms/data_selection.py
@@ -4,14 +4,15 @@
def data_selection(
- workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]:
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json
index d8b4ed1b9..894ebb9fb 100644
--- a/reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json
@@ -1,37 +1,23 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.028609
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.981543
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9978504782314613
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 1357
- ]
- },
- "decay_steps_factor": {
- "feasible_points": [
- 0.984398
- ]
- },
- "end_factor": {
- "feasible_points": [
- 0.01
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.000576
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.028609]
+ },
+ "beta1": {
+ "feasible_points": [0.981543]
+ },
+ "beta2": {
+ "feasible_points": [0.9978504782314613]
+ },
+ "warmup_steps": {
+ "feasible_points": [1357]
+ },
+ "decay_steps_factor": {
+ "feasible_points": [0.984398]
+ },
+ "end_factor": {
+ "feasible_points": [0.01]
+ },
+ "weight_decay": {
+ "feasible_points": [0.000576]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/fastmri_layernorm/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/fastmri_layernorm/tuning_search_space.json
index 7a49ea891..a3aa8ea08 100644
--- a/reference_algorithms/target_setting_algorithms/fastmri_layernorm/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/fastmri_layernorm/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.008334676559764446
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.8294338711079317
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.8551723332825868
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 2714
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.01371235755699044
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.008334676559764446]
+ },
+ "beta1": {
+ "feasible_points": [0.8294338711079317]
+ },
+ "beta2": {
+ "feasible_points": [0.8551723332825868]
+ },
+ "warmup_steps": {
+ "feasible_points": [2714]
+ },
+ "weight_decay": {
+ "feasible_points": [0.01371235755699044]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/fastmri_model_size/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/fastmri_model_size/tuning_search_space.json
index 5516242df..21c5ac87c 100644
--- a/reference_algorithms/target_setting_algorithms/fastmri_model_size/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/fastmri_model_size/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.006173154695175443
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.8496694604806512
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.4639437428687345
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 1357
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.1679001017957879
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.006173154695175443]
+ },
+ "beta1": {
+ "feasible_points": [0.8496694604806512]
+ },
+ "beta2": {
+ "feasible_points": [0.4639437428687345]
+ },
+ "warmup_steps": {
+ "feasible_points": [1357]
+ },
+ "weight_decay": {
+ "feasible_points": [0.1679001017957879]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/fastmri_tanh/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/fastmri_tanh/tuning_search_space.json
index c3f06a686..59d624fe9 100644
--- a/reference_algorithms/target_setting_algorithms/fastmri_tanh/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/fastmri_tanh/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.04037951750205473
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9932215932637941
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9425306939334134
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 542
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.14877061239151607
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.04037951750205473]
+ },
+ "beta1": {
+ "feasible_points": [0.9932215932637941]
+ },
+ "beta2": {
+ "feasible_points": [0.9425306939334134]
+ },
+ "warmup_steps": {
+ "feasible_points": [542]
+ },
+ "weight_decay": {
+ "feasible_points": [0.14877061239151607]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json
index 649487c48..941bac70e 100644
--- a/reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json
@@ -1,42 +1,26 @@
{
- "learning_rate": {
- "feasible_points": [
- 4.131896390902391
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9274758113254791
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9978504782314613
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6999
- ]
- },
- "decay_steps_factor": {
- "feasible_points": [
- 0.9007765761611038
- ]
- },
- "end_factor": {
- "feasible_points": [
- 0.001
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 5.6687777311501786e-6
- ]
- },
- "label_smoothing": {
- "feasible_points": [
- 0.2
- ]
- }
+ "learning_rate": {
+ "feasible_points": [4.131896390902391]
+ },
+ "beta1": {
+ "feasible_points": [0.9274758113254791]
+ },
+ "beta2": {
+ "feasible_points": [0.9978504782314613]
+ },
+ "warmup_steps": {
+ "feasible_points": [6999]
+ },
+ "decay_steps_factor": {
+ "feasible_points": [0.9007765761611038]
+ },
+ "end_factor": {
+ "feasible_points": [0.001]
+ },
+ "weight_decay": {
+ "feasible_points": [5.6687777311501786e-6]
+ },
+ "label_smoothing": {
+ "feasible_points": [0.2]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/imagenet_resnet_gelu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_resnet_gelu/tuning_search_space.json
index 6524f5a5b..f72a4057d 100644
--- a/reference_algorithms/target_setting_algorithms/imagenet_resnet_gelu/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/imagenet_resnet_gelu/tuning_search_space.json
@@ -1,37 +1,23 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.3850582234619253
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9845129495436189
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6999
- ]
- },
- "decay_steps_factor": {
- "feasible_points": [
- 0.9504205232618159
- ]
- },
- "end_factor": {
- "feasible_points": [
- 0.001
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 1.7359160785435053e-5
- ]
- },
- "label_smoothing": {
- "feasible_points": [
- 0.2
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.3850582234619253]
+ },
+ "beta1": {
+ "feasible_points": [0.9845129495436189]
+ },
+ "warmup_steps": {
+ "feasible_points": [6999]
+ },
+ "decay_steps_factor": {
+ "feasible_points": [0.9504205232618159]
+ },
+ "end_factor": {
+ "feasible_points": [0.001]
+ },
+ "weight_decay": {
+ "feasible_points": [1.7359160785435053e-5]
+ },
+ "label_smoothing": {
+ "feasible_points": [0.2]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/imagenet_resnet_large_bn_init/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_resnet_large_bn_init/tuning_search_space.json
index 7ad32bb60..f58474cc8 100644
--- a/reference_algorithms/target_setting_algorithms/imagenet_resnet_large_bn_init/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/imagenet_resnet_large_bn_init/tuning_search_space.json
@@ -1,37 +1,23 @@
{
- "learning_rate": {
- "feasible_points": [
- 4.131896390902391
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9274758113254791
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6999
- ]
- },
- "decay_steps_factor": {
- "feasible_points": [
- 0.9007765761611038
- ]
- },
- "end_factor": {
- "feasible_points": [
- 0.01
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 5.6687777311501786e-6
- ]
- },
- "label_smoothing": {
- "feasible_points": [
- 0.2
- ]
- }
+ "learning_rate": {
+ "feasible_points": [4.131896390902391]
+ },
+ "beta1": {
+ "feasible_points": [0.9274758113254791]
+ },
+ "warmup_steps": {
+ "feasible_points": [6999]
+ },
+ "decay_steps_factor": {
+ "feasible_points": [0.9007765761611038]
+ },
+ "end_factor": {
+ "feasible_points": [0.01]
+ },
+ "weight_decay": {
+ "feasible_points": [5.6687777311501786e-6]
+ },
+ "label_smoothing": {
+ "feasible_points": [0.2]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/imagenet_resnet_silu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_resnet_silu/tuning_search_space.json
index 4556c6235..c63a87214 100644
--- a/reference_algorithms/target_setting_algorithms/imagenet_resnet_silu/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/imagenet_resnet_silu/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.01897755400372091
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9666072782043229
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.99681600289198
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.015653883841116094
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.01897755400372091]
+ },
+ "beta1": {
+ "feasible_points": [0.9666072782043229]
+ },
+ "beta2": {
+ "feasible_points": [0.99681600289198]
+ },
+ "warmup_steps": {
+ "feasible_points": [6999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.015653883841116094]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json
index 98360bdff..6c7501295 100644
--- a/reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0008445074561975979
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.8895758153482813
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9978504782314613
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.08135402759553023
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0008445074561975979]
+ },
+ "beta1": {
+ "feasible_points": [0.8895758153482813]
+ },
+ "beta2": {
+ "feasible_points": [0.9978504782314613]
+ },
+ "warmup_steps": {
+ "feasible_points": [6999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.08135402759553023]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/imagenet_vit_glu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_vit_glu/tuning_search_space.json
index 98360bdff..6c7501295 100644
--- a/reference_algorithms/target_setting_algorithms/imagenet_vit_glu/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/imagenet_vit_glu/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0008445074561975979
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.8895758153482813
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9978504782314613
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.08135402759553023
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0008445074561975979]
+ },
+ "beta1": {
+ "feasible_points": [0.8895758153482813]
+ },
+ "beta2": {
+ "feasible_points": [0.9978504782314613]
+ },
+ "warmup_steps": {
+ "feasible_points": [6999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.08135402759553023]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/imagenet_vit_map/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_vit_map/tuning_search_space.json
index 98360bdff..6c7501295 100644
--- a/reference_algorithms/target_setting_algorithms/imagenet_vit_map/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/imagenet_vit_map/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0008445074561975979
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.8895758153482813
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9978504782314613
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.08135402759553023
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0008445074561975979]
+ },
+ "beta1": {
+ "feasible_points": [0.8895758153482813]
+ },
+ "beta2": {
+ "feasible_points": [0.9978504782314613]
+ },
+ "warmup_steps": {
+ "feasible_points": [6999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.08135402759553023]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/imagenet_vit_post_ln/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_vit_post_ln/tuning_search_space.json
index d6f2053ff..94711417c 100644
--- a/reference_algorithms/target_setting_algorithms/imagenet_vit_post_ln/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/imagenet_vit_post_ln/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.00026032497966327757
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9709035036599892
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.6572080806975734
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 13999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.03077045727617869
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.00026032497966327757]
+ },
+ "beta1": {
+ "feasible_points": [0.9709035036599892]
+ },
+ "beta2": {
+ "feasible_points": [0.6572080806975734]
+ },
+ "warmup_steps": {
+ "feasible_points": [13999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.03077045727617869]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/jax_adamw.py b/reference_algorithms/target_setting_algorithms/jax_adamw.py
index b64f0dfd6..3fa5e4955 100644
--- a/reference_algorithms/target_setting_algorithms/jax_adamw.py
+++ b/reference_algorithms/target_setting_algorithms/jax_adamw.py
@@ -1,44 +1,54 @@
"""Submission file for an AdamW optimizer with warmup+cosine LR in Jax."""
-from flax import jax_utils
+
import jax
import jax.numpy as jnp
import optax
+from flax import jax_utils
from algoperf import spec
from reference_algorithms.target_setting_algorithms import cosine_warmup
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import \
- get_batch_size # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_submission_base import \
- update_params # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
+ data_selection,
+)
+from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401
+ get_batch_size,
+)
+from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401
+ update_params,
+)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates an AdamW optimizer and a learning rate schedule."""
del model_params
del model_state
del rng
target_setting_step_hint = int(0.75 * workload.step_hint)
- lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint,
- hyperparameters)
+ lr_schedule_fn = cosine_warmup.jax_cosine_warmup(
+ target_setting_step_hint, hyperparameters
+ )
# Create optimizer.
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
epsilon = (
- hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8)
+ hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8
+ )
opt_init_fn, opt_update_fn = optax.adamw(
- learning_rate=lr_schedule_fn,
- b1=hyperparameters.beta1,
- b2=hyperparameters.beta2,
- eps=epsilon,
- weight_decay=hyperparameters.weight_decay)
+ learning_rate=lr_schedule_fn,
+ b1=hyperparameters.beta1,
+ b2=hyperparameters.beta2,
+ eps=epsilon,
+ weight_decay=hyperparameters.weight_decay,
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
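
As a rough illustration of what the reformatted init_optimizer_state builds, the sketch below (not part of the patch) pairs an optax schedule with optax.adamw and initializes the optimizer state from zero-valued arrays, mirroring the params_zeros_like trick; the schedule, parameter shapes, and hyperparameter values are stand-ins.

import jax.numpy as jnp
import optax

# Hypothetical schedule and hyperparameters.
schedule_fn = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=1_000)
opt_init_fn, opt_update_fn = optax.adamw(
  learning_rate=schedule_fn, b1=0.9, b2=0.999, eps=1e-8, weight_decay=1e-2
)

# Optimizer state depends only on parameter shapes, so zero-valued arrays of
# the right shape are enough to initialize it.
params_zeros_like = {'w': jnp.zeros((4, 4)), 'b': jnp.zeros((4,))}
optimizer_state = opt_init_fn(params_zeros_like)
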
diff --git a/reference_algorithms/target_setting_algorithms/jax_momentum.py b/reference_algorithms/target_setting_algorithms/jax_momentum.py
index a6c3d853b..92d3f3a8f 100644
--- a/reference_algorithms/target_setting_algorithms/jax_momentum.py
+++ b/reference_algorithms/target_setting_algorithms/jax_momentum.py
@@ -2,25 +2,30 @@
from typing import Callable
-from flax import jax_utils
import jax
import jax.numpy as jnp
import optax
+from flax import jax_utils
from algoperf import spec
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import \
- get_batch_size # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_submission_base import \
- update_params # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
+ data_selection,
+)
+from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401
+ get_batch_size,
+)
+from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401
+ update_params,
+)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -28,38 +33,44 @@ def init_optimizer_state(workload: spec.Workload,
# Create learning rate schedule.
target_setting_step_hint = int(0.75 * workload.step_hint)
- lr_schedule_fn = create_lr_schedule_fn(target_setting_step_hint,
- hyperparameters)
+ lr_schedule_fn = create_lr_schedule_fn(
+ target_setting_step_hint, hyperparameters
+ )
# Create optimizer.
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
opt_init_fn, opt_update_fn = sgd(
- learning_rate=lr_schedule_fn,
- weight_decay=hyperparameters.weight_decay,
- momentum=hyperparameters.beta1,
- nesterov=False)
+ learning_rate=lr_schedule_fn,
+ weight_decay=hyperparameters.weight_decay,
+ momentum=hyperparameters.beta1,
+ nesterov=False,
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
def create_lr_schedule_fn(
- step_hint: int,
- hyperparameters: spec.Hyperparameters) -> Callable[[int], float]:
+ step_hint: int, hyperparameters: spec.Hyperparameters
+) -> Callable[[int], float]:
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=hyperparameters.warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=hyperparameters.warmup_steps,
+ )
decay_steps = step_hint - hyperparameters.warmup_steps
polynomial_schedule_fn = optax.polynomial_schedule(
- init_value=hyperparameters.learning_rate,
- end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
- power=1,
- transition_steps=int(decay_steps * hyperparameters.decay_steps_factor))
+ init_value=hyperparameters.learning_rate,
+ end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
+ power=1,
+ transition_steps=int(decay_steps * hyperparameters.decay_steps_factor),
+ )
lr_schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, polynomial_schedule_fn],
- boundaries=[hyperparameters.warmup_steps])
+ schedules=[warmup_fn, polynomial_schedule_fn],
+ boundaries=[hyperparameters.warmup_steps],
+ )
return lr_schedule_fn
@@ -85,6 +96,8 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False):
{SGD, Momentum, Nesterov} update.
"""
return optax.chain(
- optax.add_decayed_weights(weight_decay),
- optax.sgd(
- learning_rate=learning_rate, momentum=momentum, nesterov=nesterov))
+ optax.add_decayed_weights(weight_decay),
+ optax.sgd(
+ learning_rate=learning_rate, momentum=momentum, nesterov=nesterov
+ ),
+ )
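
The create_lr_schedule_fn reformatted above composes a linear warmup with a power-1 polynomial (i.e. linear) decay. A self-contained sketch (not part of the patch) with hypothetical hyperparameter values:

import optax

# Hypothetical stand-ins for learning_rate, warmup_steps, end_factor, and
# decay_steps_factor from the tuning search spaces.
learning_rate, warmup_steps, step_hint = 0.1, 50, 1_000
end_factor, decay_steps_factor = 0.001, 0.9

warmup_fn = optax.linear_schedule(
  init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps
)
decay_steps = step_hint - warmup_steps
polynomial_fn = optax.polynomial_schedule(
  init_value=learning_rate,
  end_value=learning_rate * end_factor,
  power=1,
  transition_steps=int(decay_steps * decay_steps_factor),
)
lr_schedule_fn = optax.join_schedules(
  schedules=[warmup_fn, polynomial_fn], boundaries=[warmup_steps]
)

# The joined schedule is queried with a step count: lr_schedule_fn(0) is 0.0
# and lr_schedule_fn(warmup_steps) is learning_rate.
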
diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py
index 597a43c9e..6883b17ab 100644
--- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py
+++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py
@@ -3,33 +3,35 @@
from typing import Any, Callable, NamedTuple, Optional, Union
import chex
-from flax import jax_utils
import jax
import jax.numpy as jnp
import optax
+from flax import jax_utils
from algoperf import spec
from reference_algorithms.target_setting_algorithms import cosine_warmup
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import \
- get_batch_size # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_submission_base import \
- update_params # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
+ data_selection,
+)
+from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401
+ get_batch_size,
+)
+from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401
+ update_params,
+)
# Forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py
def nadamw(
- learning_rate: Union[float, optax.Schedule],
- b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- weight_decay: float = 0.0,
- weight_decay_mask: Optional[Union[Any, Callable[[optax.Params],
- Any]]] = None,
+ learning_rate: Union[float, optax.Schedule],
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ weight_decay: float = 0.0,
+ weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None,
) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
References:
@@ -61,19 +63,22 @@ def nadamw(
An (init_fn, update_fn) tuple.
"""
return optax.chain(
- scale_by_nadam(b1, b2, eps, eps_root, debias),
- optax.add_decayed_weights(weight_decay, weight_decay_mask),
- scale_by_learning_rate(learning_rate))
+ scale_by_nadam(b1, b2, eps, eps_root, debias),
+ optax.add_decayed_weights(weight_decay, weight_decay_mask),
+ scale_by_learning_rate(learning_rate),
+ )
# All functions below are forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py
-def scale_by_nadam(b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- power: float = 0.5) -> optax.GradientTransformation:
+def scale_by_nadam(
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ power: float = 0.5,
+) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
References:
There seem to be multiple versions of NAdam. The original version is here
@@ -109,7 +114,8 @@ def update_fn(updates, state, params=None):
mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
nu_hat = nu if not debias else _bias_correction(nu, b2, count)
updates = jax.tree.map(
- lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
+ lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat
+ )
return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
return optax.GradientTransformation(init_fn, update_fn)
@@ -117,6 +123,7 @@ def update_fn(updates, state, params=None):
class ScaleByAdamState(NamedTuple):
"""State for the NAdam algorithm."""
+
count: chex.Array # shape=(), dtype=jnp.int32.
mu: optax.Updates
nu: optax.Updates
@@ -125,7 +132,8 @@ class ScaleByAdamState(NamedTuple):
def _update_moment(updates, moments, decay, order):
"""Compute the exponential moving average of the `order-th` moment."""
return jax.tree.map(
- lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
+ lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments
+ )
def _bias_correction(moment, decay, count):
@@ -141,31 +149,37 @@ def scale_by_learning_rate(learning_rate, flip_sign=True):
return optax.scale(m * learning_rate)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_params
del model_state
del rng
target_setting_step_hint = int(0.75 * workload.step_hint)
- lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint,
- hyperparameters)
+ lr_schedule_fn = cosine_warmup.jax_cosine_warmup(
+ target_setting_step_hint, hyperparameters
+ )
# Create optimizer.
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
epsilon = (
- hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8)
+ hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8
+ )
opt_init_fn, opt_update_fn = nadamw(
- learning_rate=lr_schedule_fn,
- b1=hyperparameters.beta1,
- b2=hyperparameters.beta2,
- eps=epsilon,
- weight_decay=hyperparameters.weight_decay)
+ learning_rate=lr_schedule_fn,
+ b1=hyperparameters.beta1,
+ b2=hyperparameters.beta2,
+ eps=epsilon,
+ weight_decay=hyperparameters.weight_decay,
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
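
The nadamw helper above follows the usual optax composition pattern: a second-moment scaling transform, then decoupled weight decay, then a learning-rate scale. A minimal sketch of that pattern (not part of the patch), with optax.scale_by_adam standing in for the repository's custom scale_by_nadam:

import optax

def adamw_like(learning_rate, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.0):
  # optax.scale_by_adam is a stand-in for scale_by_nadam; the surrounding
  # chain (decoupled weight decay, then a negative learning-rate scale) is
  # the same composition used above.
  return optax.chain(
    optax.scale_by_adam(b1=b1, b2=b2, eps=eps),
    optax.add_decayed_weights(weight_decay),
    optax.scale(-learning_rate),
  )

For a learning-rate schedule rather than a constant, optax.scale_by_schedule plays the role of the final optax.scale.
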
diff --git a/reference_algorithms/target_setting_algorithms/jax_nesterov.py b/reference_algorithms/target_setting_algorithms/jax_nesterov.py
index 0c11044fc..7f4d1cd86 100644
--- a/reference_algorithms/target_setting_algorithms/jax_nesterov.py
+++ b/reference_algorithms/target_setting_algorithms/jax_nesterov.py
@@ -2,25 +2,30 @@
from typing import Callable
-from flax import jax_utils
import jax
import jax.numpy as jnp
import optax
+from flax import jax_utils
from algoperf import spec
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import \
- get_batch_size # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_submission_base import \
- update_params # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
+ data_selection,
+)
+from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401
+ get_batch_size,
+)
+from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401
+ update_params,
+)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -28,38 +33,44 @@ def init_optimizer_state(workload: spec.Workload,
# Create learning rate schedule.
target_setting_step_hint = int(0.75 * workload.step_hint)
- lr_schedule_fn = create_lr_schedule_fn(target_setting_step_hint,
- hyperparameters)
+ lr_schedule_fn = create_lr_schedule_fn(
+ target_setting_step_hint, hyperparameters
+ )
# Create optimizer.
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
opt_init_fn, opt_update_fn = sgd(
- learning_rate=lr_schedule_fn,
- weight_decay=hyperparameters.weight_decay,
- momentum=hyperparameters.beta1,
- nesterov=True)
+ learning_rate=lr_schedule_fn,
+ weight_decay=hyperparameters.weight_decay,
+ momentum=hyperparameters.beta1,
+ nesterov=True,
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
def create_lr_schedule_fn(
- step_hint: int,
- hyperparameters: spec.Hyperparameters) -> Callable[[int], float]:
+ step_hint: int, hyperparameters: spec.Hyperparameters
+) -> Callable[[int], float]:
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=hyperparameters.warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=hyperparameters.warmup_steps,
+ )
decay_steps = step_hint - hyperparameters.warmup_steps
polynomial_schedule_fn = optax.polynomial_schedule(
- init_value=hyperparameters.learning_rate,
- end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
- power=1,
- transition_steps=int(decay_steps * hyperparameters.decay_steps_factor))
+ init_value=hyperparameters.learning_rate,
+ end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
+ power=1,
+ transition_steps=int(decay_steps * hyperparameters.decay_steps_factor),
+ )
lr_schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, polynomial_schedule_fn],
- boundaries=[hyperparameters.warmup_steps])
+ schedules=[warmup_fn, polynomial_schedule_fn],
+ boundaries=[hyperparameters.warmup_steps],
+ )
return lr_schedule_fn
@@ -85,6 +96,8 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False):
{SGD, Momentum, Nesterov} update.
"""
return optax.chain(
- optax.add_decayed_weights(weight_decay),
- optax.sgd(
- learning_rate=learning_rate, momentum=momentum, nesterov=nesterov))
+ optax.add_decayed_weights(weight_decay),
+ optax.sgd(
+ learning_rate=learning_rate, momentum=momentum, nesterov=nesterov
+ ),
+ )
diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py
index 217228935..d9b12f5ca 100644
--- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py
+++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py
@@ -1,11 +1,12 @@
"""Update submission function in Jax."""
+
import functools
from typing import Any, Dict, List, Optional, Tuple
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from jax import lax
from algoperf import spec
@@ -13,75 +14,84 @@
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -98,32 +108,46 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- new_optimizer_state, new_params, new_model_state, loss, grad_norm = pmapped_train_step( # pylint: disable=line-too-long
- workload, opt_update_fn, model_state, optimizer_state,
- current_param_container, batch, per_device_rngs, grad_clip,
- label_smoothing)
+ new_optimizer_state, new_params, new_model_state, loss, grad_norm = (
+    pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
+ )
# Log loss, grad_norm.
- if ((global_step <= 100 or global_step % 500 == 0) and
- workload.metrics_logger is not None):
+ if (
+ global_step <= 100 or global_step % 500 == 0
+ ) and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
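
The gradient handling inside pmapped_train_step clips by global norm: the whole gradient pytree is rescaled so its L2 norm does not exceed grad_clip. A standalone sketch of that step (not part of the patch; eps stands in for the module's _GRAD_CLIP_EPS constant):

import jax
import jax.numpy as jnp

def clip_by_global_norm(grad, grad_clip, eps=1e-6):
  # Global L2 norm over all leaves of the gradient pytree.
  grad_norm = jnp.sqrt(
    sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
  )
  # The scale factor is clamped to [0, 1], so small gradients pass through.
  scale = jax.lax.clamp(min=0.0, x=grad_clip / (grad_norm + eps), max=1.0)
  return jax.tree.map(lambda g: g * scale, grad)

# Example: clip a small gradient pytree to a global norm of at most 1.0.
grads = {'w': jnp.ones((2, 2)), 'b': jnp.ones((2,))}
clipped = clip_by_global_norm(grads, grad_clip=1.0)
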
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
index 482a28931..936833cf3 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.002106913873888147
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.8231189937738506
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.8774571227688758
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 1199
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.27590534177690645
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.002106913873888147]
+ },
+ "beta1": {
+ "feasible_points": [0.8231189937738506]
+ },
+ "beta2": {
+ "feasible_points": [0.8774571227688758]
+ },
+ "warmup_steps": {
+ "feasible_points": [1199]
+ },
+ "weight_decay": {
+ "feasible_points": [0.27590534177690645]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json
index 22f3376b4..faefa750e 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0007852999990476642
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.6994142393023162
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9918636824608852
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6000
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.07286322158086678
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0007852999990476642]
+ },
+ "beta1": {
+ "feasible_points": [0.6994142393023162]
+ },
+ "beta2": {
+ "feasible_points": [0.9918636824608852]
+ },
+ "warmup_steps": {
+ "feasible_points": [6000]
+ },
+ "weight_decay": {
+ "feasible_points": [0.07286322158086678]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json
index ad200c01b..16ab02525 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.000590120167916659
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.737199286155609
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.05919391544031072
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 9999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.14128519778326312
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.000590120167916659]
+ },
+ "beta1": {
+ "feasible_points": [0.737199286155609]
+ },
+ "beta2": {
+ "feasible_points": [0.05919391544031072]
+ },
+ "warmup_steps": {
+ "feasible_points": [9999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.14128519778326312]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json
index 8297cf0ae..d596dcd2b 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0014446807792420305
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.7427148812902895
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.8993064520764248
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 3000
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.06875136511682291
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0014446807792420305]
+ },
+ "beta1": {
+ "feasible_points": [0.7427148812902895]
+ },
+ "beta2": {
+ "feasible_points": [0.8993064520764248]
+ },
+ "warmup_steps": {
+ "feasible_points": [3000]
+ },
+ "weight_decay": {
+ "feasible_points": [0.06875136511682291]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_no_resnet/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_no_resnet/tuning_search_space.json
index b31b711f7..dbcbecf78 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_no_resnet/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_no_resnet/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0035278622506232458
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.8192305396005781
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.495850879212151
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6000
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.04339748256184769
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0035278622506232458]
+ },
+ "beta1": {
+ "feasible_points": [0.8192305396005781]
+ },
+ "beta2": {
+ "feasible_points": [0.495850879212151]
+ },
+ "warmup_steps": {
+ "feasible_points": [6000]
+ },
+ "weight_decay": {
+ "feasible_points": [0.04339748256184769]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json
index e20a2dae1..bbea133cb 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.001308209823469072
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9731333693827139
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9981232922116359
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6000
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.16375311233774334
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.001308209823469072]
+ },
+ "beta1": {
+ "feasible_points": [0.9731333693827139]
+ },
+ "beta2": {
+ "feasible_points": [0.9981232922116359]
+ },
+ "warmup_steps": {
+ "feasible_points": [6000]
+ },
+ "weight_decay": {
+ "feasible_points": [0.16375311233774334]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json
index 0a9bfb3cf..52fe59d84 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.004958460849689891
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.863744242567442
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.6291854735396584
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 720
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.1147386261512052
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.004958460849689891]
+ },
+ "beta1": {
+ "feasible_points": [0.863744242567442]
+ },
+ "beta2": {
+ "feasible_points": [0.6291854735396584]
+ },
+ "warmup_steps": {
+ "feasible_points": [720]
+ },
+ "weight_decay": {
+ "feasible_points": [0.1147386261512052]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json
index e76a48325..898fc9e36 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0020162740358935045
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9604907112078142
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.8765457000160508
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 3600
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.0006149579248633481
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0020162740358935045]
+ },
+ "beta1": {
+ "feasible_points": [0.9604907112078142]
+ },
+ "beta2": {
+ "feasible_points": [0.8765457000160508]
+ },
+ "warmup_steps": {
+ "feasible_points": [3600]
+ },
+ "weight_decay": {
+ "feasible_points": [0.0006149579248633481]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json
index 55f70f9fc..94f150ad3 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0014446807792420305
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.7427148812902895
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.8993064520764248
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 1800
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.06875136511682291
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0014446807792420305]
+ },
+ "beta1": {
+ "feasible_points": [0.7427148812902895]
+ },
+ "beta2": {
+ "feasible_points": [0.8993064520764248]
+ },
+ "warmup_steps": {
+ "feasible_points": [1800]
+ },
+ "weight_decay": {
+ "feasible_points": [0.06875136511682291]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json
index e5f906688..517e4a455 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.003604759885558324
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9931094324430452
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9976871843749077
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 720
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.120077307855989
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.003604759885558324]
+ },
+ "beta1": {
+ "feasible_points": [0.9931094324430452]
+ },
+ "beta2": {
+ "feasible_points": [0.9976871843749077]
+ },
+ "warmup_steps": {
+ "feasible_points": [720]
+ },
+ "weight_decay": {
+ "feasible_points": [0.120077307855989]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json
index 0f365a183..266f0f3f5 100644
--- a/reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json
@@ -1,32 +1,20 @@
{
- "learning_rate": {
- "feasible_points": [
- 2.4917728606918423
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9449369031171744
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 3000
- ]
- },
- "decay_steps_factor": {
- "feasible_points": [
- 0.861509027839639
- ]
- },
- "end_factor": {
- "feasible_points": [
- 0.001
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 1.2859640541025928e-7
- ]
- }
+ "learning_rate": {
+ "feasible_points": [2.4917728606918423]
+ },
+ "beta1": {
+ "feasible_points": [0.9449369031171744]
+ },
+ "warmup_steps": {
+ "feasible_points": [3000]
+ },
+ "decay_steps_factor": {
+ "feasible_points": [0.861509027839639]
+ },
+ "end_factor": {
+ "feasible_points": [0.001]
+ },
+ "weight_decay": {
+ "feasible_points": [1.2859640541025928e-7]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/ogbg_gelu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/ogbg_gelu/tuning_search_space.json
index 0749f96d6..b5b4aad30 100644
--- a/reference_algorithms/target_setting_algorithms/ogbg_gelu/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/ogbg_gelu/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.01897755400372091
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9666072782043229
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.99681600289198
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 3000
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.015653883841116094
- ]
- }
-}
\ No newline at end of file
+ "learning_rate": {
+ "feasible_points": [0.01897755400372091]
+ },
+ "beta1": {
+ "feasible_points": [0.9666072782043229]
+ },
+ "beta2": {
+ "feasible_points": [0.99681600289198]
+ },
+ "warmup_steps": {
+ "feasible_points": [3000]
+ },
+ "weight_decay": {
+ "feasible_points": [0.015653883841116094]
+ }
+}
diff --git a/reference_algorithms/target_setting_algorithms/ogbg_model_size/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/ogbg_model_size/tuning_search_space.json
index d5af3b03e..b69114b88 100644
--- a/reference_algorithms/target_setting_algorithms/ogbg_model_size/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/ogbg_model_size/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.001734480757979605
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.855609542347586
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9834185656478605
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 3000
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.019843063335529494
- ]
- }
-}
\ No newline at end of file
+ "learning_rate": {
+ "feasible_points": [0.001734480757979605]
+ },
+ "beta1": {
+ "feasible_points": [0.855609542347586]
+ },
+ "beta2": {
+ "feasible_points": [0.9834185656478605]
+ },
+ "warmup_steps": {
+ "feasible_points": [3000]
+ },
+ "weight_decay": {
+ "feasible_points": [0.019843063335529494]
+ }
+}
diff --git a/reference_algorithms/target_setting_algorithms/ogbg_silu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/ogbg_silu/tuning_search_space.json
index b9f83a5ed..e1512c02a 100644
--- a/reference_algorithms/target_setting_algorithms/ogbg_silu/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/ogbg_silu/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.00027866530268792414
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9919340993463499
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9979843253162892
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6000
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.00032418357325210813
- ]
- }
-}
\ No newline at end of file
+ "learning_rate": {
+ "feasible_points": [0.00027866530268792414]
+ },
+ "beta1": {
+ "feasible_points": [0.9919340993463499]
+ },
+ "beta2": {
+ "feasible_points": [0.9979843253162892]
+ },
+ "warmup_steps": {
+ "feasible_points": [6000]
+ },
+ "weight_decay": {
+ "feasible_points": [0.00032418357325210813]
+ }
+}
diff --git a/reference_algorithms/target_setting_algorithms/pytorch_adamw.py b/reference_algorithms/target_setting_algorithms/pytorch_adamw.py
index c87bdfb7d..14e8155d4 100644
--- a/reference_algorithms/target_setting_algorithms/pytorch_adamw.py
+++ b/reference_algorithms/target_setting_algorithms/pytorch_adamw.py
@@ -4,37 +4,44 @@
from algoperf import spec
from reference_algorithms.target_setting_algorithms import cosine_warmup
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import \
- get_batch_size # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \
- update_params # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
+ data_selection,
+)
+from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401
+ get_batch_size,
+)
+from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401
+ update_params,
+)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates an AdamW optimizer and a learning rate schedule."""
del model_state
del rng
epsilon = (
- hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8)
+ hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8
+ )
optimizer_state = {
- 'optimizer':
- torch.optim.AdamW(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(hyperparameters.beta1, hyperparameters.beta2),
- eps=epsilon,
- weight_decay=hyperparameters.weight_decay),
+ 'optimizer': torch.optim.AdamW(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(hyperparameters.beta1, hyperparameters.beta2),
+ eps=epsilon,
+ weight_decay=hyperparameters.weight_decay,
+ ),
}
target_setting_step_hint = int(0.75 * workload.step_hint)
optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup(
- target_setting_step_hint, hyperparameters, optimizer_state['optimizer'])
+ target_setting_step_hint, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
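`init_optimizer_state` above pairs `torch.optim.AdamW` with a warmup-plus-cosine schedule stretched over 75% of the workload's step hint via `cosine_warmup.pytorch_cosine_warmup`, whose implementation is not part of this diff. The sketch below shows one plausible shape of such a schedule using PyTorch's built-in schedulers; it approximates, rather than reproduces, the helper, and the model, learning rate, and step-hint values are made up for illustration.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR


def warmup_cosine(optimizer, warmup_steps: int, total_steps: int):
  """Linear warmup followed by cosine decay, stepped once per iteration."""
  warmup = LinearLR(
    optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
  )
  cosine = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps)
  return SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps]
  )


# Hypothetical usage mirroring the diff: the schedule length is 75% of the
# workload's step hint, and `warmup_steps` comes from the tuning search space.
model = torch.nn.Linear(4, 2)
opt = torch.optim.AdamW(model.parameters(), lr=2.5e-3, weight_decay=1e-7)
sched = warmup_cosine(opt, warmup_steps=3000, total_steps=int(0.75 * 80_000))
```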
diff --git a/reference_algorithms/target_setting_algorithms/pytorch_momentum.py b/reference_algorithms/target_setting_algorithms/pytorch_momentum.py
index 584caff39..a23b835eb 100644
--- a/reference_algorithms/target_setting_algorithms/pytorch_momentum.py
+++ b/reference_algorithms/target_setting_algorithms/pytorch_momentum.py
@@ -4,40 +4,47 @@
from torch.optim.lr_scheduler import LambdaLR
from algoperf import spec
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import \
- get_batch_size # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_momentum import \
- create_lr_schedule_fn
-from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \
- update_params # pylint: disable=unused-import
-
-
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
+ data_selection,
+)
+from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401
+ get_batch_size,
+)
+from reference_algorithms.target_setting_algorithms.jax_momentum import (
+ create_lr_schedule_fn,
+)
+from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401
+ update_params,
+)
+
+
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule."""
del model_state
del rng
# Create optimizer.
optimizer_state = {
- 'optimizer':
- torch.optim.SGD(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- momentum=hyperparameters.beta1,
- weight_decay=hyperparameters.weight_decay,
- nesterov=False),
+ 'optimizer': torch.optim.SGD(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ momentum=hyperparameters.beta1,
+ weight_decay=hyperparameters.weight_decay,
+ nesterov=False,
+ ),
}
# Create learning rate schedule.
target_setting_step_hint = int(0.75 * workload.step_hint)
- lr_schedule_fn = create_lr_schedule_fn(target_setting_step_hint,
- hyperparameters)
+ lr_schedule_fn = create_lr_schedule_fn(
+ target_setting_step_hint, hyperparameters
+ )
# PyTorch's LambdaLR expects the lr_lambda fn to return a factor which will
# be multiplied with the base lr, so we have to divide by it here.
@@ -45,6 +52,7 @@ def _lr_lambda(step: int) -> float:
return lr_schedule_fn(step).item() / hyperparameters.learning_rate
optimizer_state['scheduler'] = LambdaLR(
- optimizer_state['optimizer'], lr_lambda=_lr_lambda)
+ optimizer_state['optimizer'], lr_lambda=_lr_lambda
+ )
return optimizer_state
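The comment above explains the one subtle step in this file: `create_lr_schedule_fn` returns absolute learning rates, while `LambdaLR` expects a multiplicative factor on the base learning rate, hence the division. A toy sketch of that wiring, with a made-up warmup schedule standing in for the real schedule function:

```python
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR

base_lr = 0.1


# A toy *absolute* schedule: linear warmup to base_lr over 10 steps, then flat.
# (The diff's create_lr_schedule_fn similarly returns absolute learning rates.)
def absolute_lr(step: int) -> float:
  return base_lr * min(1.0, (step + 1) / 10)


opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=base_lr)

# LambdaLR multiplies the base lr by the returned factor, so dividing the
# absolute schedule by the base lr recovers it, exactly as in _lr_lambda above.
sched = LambdaLR(opt, lr_lambda=lambda step: absolute_lr(step) / base_lr)

for step in range(3):
  opt.step()
  sched.step()
  assert np.isclose(opt.param_groups[0]['lr'], absolute_lr(step + 1))
```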
diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py
index a9dee1d79..d301a233f 100644
--- a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py
+++ b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py
@@ -8,43 +8,43 @@
from algoperf import spec
from reference_algorithms.target_setting_algorithms import cosine_warmup
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import \
- get_batch_size # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \
- update_params # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
+ data_selection,
+)
+from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401
+ get_batch_size,
+)
+from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401
+ update_params,
+)
# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py
class NAdamW(torch.optim.Optimizer):
r"""Implements NAdamW algorithm.
- See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
- the NAdam algorithm (there is also a comment in the code which highlights
- the only difference of NAdamW and AdamW).
- For further details regarding the algorithm we refer to
- `Decoupled Weight Decay Regularization`_.
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float, optional): weight decay coefficient (default: 1e-2)
- .. _Decoupled Weight Decay Regularization:
- https://arxiv.org/abs/1711.05101
- .. _On the Convergence of Adam and Beyond:
- https://openreview.net/forum?id=ryQu7f-RZ
- """
-
- def __init__(self,
- params,
- lr=1e-3,
- betas=(0.9, 0.999),
- eps=1e-8,
- weight_decay=1e-2):
+ See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
+ the NAdam algorithm (there is also a comment in the code which highlights
+ the only difference of NAdamW and AdamW).
+ For further details regarding the algorithm we refer to
+ `Decoupled Weight Decay Regularization`_.
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
+ """
+
+ def __init__(
+ self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
+ ):
if not 0.0 <= lr:
raise ValueError(f'Invalid learning rate: {lr}')
if not 0.0 <= eps:
@@ -56,7 +56,10 @@ def __init__(self,
if not 0.0 <= weight_decay:
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
defaults = {
- 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay
+ 'lr': lr,
+ 'betas': betas,
+ 'eps': eps,
+ 'weight_decay': weight_decay,
}
super().__init__(params, defaults)
@@ -64,7 +67,8 @@ def __setstate__(self, state):
super().__setstate__(state)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
- state_values[0]['step'])
+ state_values[0]['step']
+ )
if not step_is_tensor:
for s in state_values:
s['step'] = torch.tensor(float(s['step']))
@@ -72,10 +76,10 @@ def __setstate__(self, state):
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
- Args:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
- """
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
self._cuda_graph_capture_health_check()
loss = None
@@ -103,51 +107,57 @@ def step(self, closure=None):
# State initialization
if len(state) == 0:
- state['step'] = torch.tensor(0.)
+ state['step'] = torch.tensor(0.0)
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
state_steps.append(state['step'])
nadamw(
- params_with_grad,
- grads,
- exp_avgs,
- exp_avg_sqs,
- state_steps,
- beta1=beta1,
- beta2=beta2,
- lr=group['lr'],
- weight_decay=group['weight_decay'],
- eps=group['eps'])
+ params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ state_steps,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ )
return loss
-def nadamw(params: List[Tensor],
- grads: List[Tensor],
- exp_avgs: List[Tensor],
- exp_avg_sqs: List[Tensor],
- state_steps: List[Tensor],
- beta1: float,
- beta2: float,
- lr: float,
- weight_decay: float,
- eps: float):
+def nadamw(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ state_steps: List[Tensor],
+ beta1: float,
+ beta2: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+):
r"""Functional API that performs NAdamW algorithm computation.
- See NAdamW class for details.
- """
+ See NAdamW class for details.
+ """
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
- 'API has changed, `state_steps` argument must contain a list of' +
- ' singleton tensors')
+ 'API has changed, `state_steps` argument must contain a list of'
+ + ' singleton tensors'
+ )
for i, param in enumerate(params):
grad = grads[i]
@@ -185,28 +195,32 @@ def nadamw(params: List[Tensor],
exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_state
del rng
epsilon = (
- hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8)
+ hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8
+ )
optimizer_state = {
- 'optimizer':
- NAdamW(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(hyperparameters.beta1, hyperparameters.beta2),
- eps=epsilon,
- weight_decay=hyperparameters.weight_decay),
+ 'optimizer': NAdamW(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(hyperparameters.beta1, hyperparameters.beta2),
+ eps=epsilon,
+ weight_decay=hyperparameters.weight_decay,
+ ),
}
target_setting_step_hint = int(0.75 * workload.step_hint)
optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup(
- target_setting_step_hint, hyperparameters, optimizer_state['optimizer'])
+ target_setting_step_hint, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
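The `NAdamW` class differs from AdamW only in a Nesterov-style look-ahead on the first moment; the in-place `exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1)` line visible above appears to undo that look-ahead after the parameter update so no extra tensor copy is needed. Below is a rough out-of-place NumPy restatement of a single step, written under the assumption that it should match Table 1 of arXiv:1910.05446 as the docstring indicates; it is an illustration, not the in-place tensor implementation.

```python
import numpy as np


def nadamw_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
                eps=1e-8, weight_decay=1e-2):
  """One NAdamW step on NumPy arrays (out-of-place, for readability)."""
  t += 1
  # Decoupled weight decay (the "W" part), as in AdamW.
  param = param * (1 - lr * weight_decay)
  # Standard Adam moment updates.
  m = beta1 * m + (1 - beta1) * grad
  v = beta2 * v + (1 - beta2) * grad ** 2
  # Nesterov look-ahead on the first moment -- the only difference to AdamW.
  m_lookahead = beta1 * m + (1 - beta1) * grad
  # Bias-corrected update.
  m_hat = m_lookahead / (1 - beta1 ** t)
  v_hat = v / (1 - beta2 ** t)
  param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
  return param, m, v, t


# Tiny usage example on a quadratic loss f(w) = 0.5 * w**2 (gradient = w).
w, m, v, t = np.array([1.0]), np.zeros(1), np.zeros(1), 0
for _ in range(5):
  w, m, v, t = nadamw_step(w, grad=w.copy(), m=m, v=v, t=t)
```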
diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py b/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py
index 8e10db4ef..3a6294a28 100644
--- a/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py
+++ b/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py
@@ -4,40 +4,47 @@
from torch.optim.lr_scheduler import LambdaLR
from algoperf import spec
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import \
- get_batch_size # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_nesterov import \
- create_lr_schedule_fn
-from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \
- update_params # pylint: disable=unused-import
-
-
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
+ data_selection,
+)
+from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401
+ get_batch_size,
+)
+from reference_algorithms.target_setting_algorithms.jax_momentum import (
+ create_lr_schedule_fn,
+)
+from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401
+ update_params,
+)
+
+
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule."""
del model_state
del rng
# Create optimizer.
optimizer_state = {
- 'optimizer':
- torch.optim.SGD(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- momentum=hyperparameters.beta1,
- weight_decay=hyperparameters.weight_decay,
- nesterov=True),
+ 'optimizer': torch.optim.SGD(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ momentum=hyperparameters.beta1,
+ weight_decay=hyperparameters.weight_decay,
+ nesterov=True,
+ ),
}
# Create learning rate schedule.
target_setting_step_hint = int(0.75 * workload.step_hint)
- lr_schedule_fn = create_lr_schedule_fn(target_setting_step_hint,
- hyperparameters)
+ lr_schedule_fn = create_lr_schedule_fn(
+ target_setting_step_hint, hyperparameters
+ )
# PyTorch's LambdaLR expects the lr_lambda fn to return a factor which will
# be multiplied with the base lr, so we have to divide by it here.
@@ -45,6 +52,7 @@ def _lr_lambda(step: int) -> float:
return lr_schedule_fn(step).item() / hyperparameters.learning_rate
optimizer_state['scheduler'] = LambdaLR(
- optimizer_state['optimizer'], lr_lambda=_lr_lambda)
+ optimizer_state['optimizer'], lr_lambda=_lr_lambda
+ )
return optimizer_state
diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py
index bbfd8b0f2..0bef4548f 100644
--- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py
+++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py
@@ -2,9 +2,9 @@
from typing import Any, Dict, List, Optional, Tuple
-from absl import logging
import torch
import torch.distributed.nn as dist_nn
+from absl import logging
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
@@ -13,18 +13,19 @@
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -36,26 +37,30 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -68,7 +73,8 @@ def update_params(
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
if 'scheduler' in optimizer_state:
optimizer_state['scheduler'].step()
@@ -78,32 +84,39 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
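One quantity `update_params` logs is a global gradient norm, computed as the 2-norm of the stacked per-parameter gradient norms. A small standalone sketch of that computation, useful for checking the logged values outside the benchmark harness (the toy model and dummy loss here are placeholders):

```python
import torch


def global_grad_norm(model: torch.nn.Module) -> float:
  """L2 norm over all parameter gradients, as logged in update_params above."""
  grads = [p.grad.detach() for p in model.parameters() if p.grad is not None]
  if not grads:
    return 0.0
  per_param = torch.stack([torch.norm(g, 2) for g in grads])
  return torch.norm(per_param, 2).item()


# Usage sketch: call after loss.backward() and before optimizer.step().
model = torch.nn.Linear(3, 1)
loss = model(torch.randn(2, 3)).sum()
loss.backward()
print(global_grad_norm(model))
```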
diff --git a/reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json
index f0ef45daa..aee18d976 100644
--- a/reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json
@@ -1,32 +1,20 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0017486387539278373
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9326607383586145
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9955159689799007
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 1999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.08121616522670176
- ]
- },
- "label_smoothing": {
- "feasible_points": [
- 0.0
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0017486387539278373]
+ },
+ "beta1": {
+ "feasible_points": [0.9326607383586145]
+ },
+ "beta2": {
+ "feasible_points": [0.9955159689799007]
+ },
+ "warmup_steps": {
+ "feasible_points": [1999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.08121616522670176]
+ },
+ "label_smoothing": {
+ "feasible_points": [0.0]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/wmt_attention_temp/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/wmt_attention_temp/tuning_search_space.json
index 266cdedbb..e1ce2229f 100644
--- a/reference_algorithms/target_setting_algorithms/wmt_attention_temp/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/wmt_attention_temp/tuning_search_space.json
@@ -1,32 +1,20 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.000590120167916659
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.737199286155609
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.05919391544031072
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 9999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.14128519778326312
- ]
- },
- "label_smoothing": {
- "feasible_points": [
- 0.0
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.000590120167916659]
+ },
+ "beta1": {
+ "feasible_points": [0.737199286155609]
+ },
+ "beta2": {
+ "feasible_points": [0.05919391544031072]
+ },
+ "warmup_steps": {
+ "feasible_points": [9999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.14128519778326312]
+ },
+ "label_smoothing": {
+ "feasible_points": [0.0]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/wmt_glu_tanh/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/wmt_glu_tanh/tuning_search_space.json
index d288d9a49..0ed0f832a 100644
--- a/reference_algorithms/target_setting_algorithms/wmt_glu_tanh/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/wmt_glu_tanh/tuning_search_space.json
@@ -1,32 +1,20 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.000872041489644454
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.45562164405092065
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9982167124443476
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 4999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.01536114562763022
- ]
- },
- "label_smoothing": {
- "feasible_points": [
- 0.1
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.000872041489644454]
+ },
+ "beta1": {
+ "feasible_points": [0.45562164405092065]
+ },
+ "beta2": {
+ "feasible_points": [0.9982167124443476]
+ },
+ "warmup_steps": {
+ "feasible_points": [4999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.01536114562763022]
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/wmt_post_ln/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/wmt_post_ln/tuning_search_space.json
index 1327bcb38..d1b3045c7 100644
--- a/reference_algorithms/target_setting_algorithms/wmt_post_ln/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/wmt_post_ln/tuning_search_space.json
@@ -1,32 +1,20 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0003477912008450351
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9936632117510711
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9967873550453692
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 9999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.04120183162940475
- ]
- },
- "label_smoothing": {
- "feasible_points": [
- 0.0
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0003477912008450351]
+ },
+ "beta1": {
+ "feasible_points": [0.9936632117510711]
+ },
+ "beta2": {
+ "feasible_points": [0.9967873550453692]
+ },
+ "warmup_steps": {
+ "feasible_points": [9999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.04120183162940475]
+ },
+ "label_smoothing": {
+ "feasible_points": [0.0]
+ }
}
diff --git a/scoring/algoperf_v05/generate_held_out_workloads.py b/scoring/algoperf_v05/generate_held_out_workloads.py
index 647dc3c3d..e9ebf6a53 100644
--- a/scoring/algoperf_v05/generate_held_out_workloads.py
+++ b/scoring/algoperf_v05/generate_held_out_workloads.py
@@ -2,49 +2,51 @@
import os
import struct
-from absl import app
-from absl import flags
-from absl import logging
import numpy as np
+from absl import app, flags, logging
flags.DEFINE_integer(
- 'held_out_workloads_seed',
- None,
- 'Random seed for scoring.'
- 'AlgoPerf v0.5 seed: 3438810845')
-flags.DEFINE_string('output_filename',
- 'held_out_workloads.json',
- 'Path to file to record sampled held_out workloads.')
+ 'held_out_workloads_seed',
+ None,
+ 'Random seed for scoring. AlgoPerf v0.5 seed: 3438810845',
+)
+flags.DEFINE_string(
+ 'output_filename',
+ 'held_out_workloads.json',
+ 'Path to file to record sampled held_out workloads.',
+)
FLAGS = flags.FLAGS
HELD_OUT_WORKLOADS = {
- 'librispeech': [
- 'librispeech_conformer_attention_temperature',
- 'librispeech_conformer_layernorm',
- # 'librispeech_conformer_gelu', # Removed due to bug in target setting procedure
- 'librispeech_deepspeech_no_resnet',
- 'librispeech_deepspeech_norm_and_spec_aug',
- 'librispeech_deepspeech_tanh'
- ],
- 'imagenet': [
- 'imagenet_resnet_silu',
- 'imagenet_resnet_gelu',
- 'imagenet_resnet_large_bn_init',
- 'imagenet_vit_glu',
- 'imagenet_vit_post_ln',
- 'imagenet_vit_map'
- ],
- 'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'],
- 'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'],
- 'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'],
- 'criteo1tb': [
- 'criteo1tb_layernorm', 'criteo1tb_embed_init', 'criteo1tb_resnet'
- ]
+ 'librispeech': [
+ 'librispeech_conformer_attention_temperature',
+ 'librispeech_conformer_layernorm',
+ # 'librispeech_conformer_gelu', # Removed due to bug in target setting procedure
+ 'librispeech_deepspeech_no_resnet',
+ 'librispeech_deepspeech_norm_and_spec_aug',
+ 'librispeech_deepspeech_tanh',
+ ],
+ 'imagenet': [
+ 'imagenet_resnet_silu',
+ 'imagenet_resnet_gelu',
+ 'imagenet_resnet_large_bn_init',
+ 'imagenet_vit_glu',
+ 'imagenet_vit_post_ln',
+ 'imagenet_vit_map',
+ ],
+ 'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'],
+ 'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'],
+ 'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'],
+ 'criteo1tb': [
+ 'criteo1tb_layernorm',
+ 'criteo1tb_embed_init',
+ 'criteo1tb_resnet',
+ ],
}
def save_held_out_workloads(held_out_workloads, filename):
- with open(filename, "w") as f:
+ with open(filename, 'w') as f:
json.dump(held_out_workloads, f)
@@ -63,7 +65,7 @@ def main(_):
sampled_index = rng.integers(len(v))
sampled_held_out_workloads.append(v[sampled_index])
- logging.info(f"Sampled held-out workloads: {sampled_held_out_workloads}")
+ logging.info(f'Sampled held-out workloads: {sampled_held_out_workloads}')
save_held_out_workloads(sampled_held_out_workloads, output_filename)
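The script samples one held-out variant per base workload so that scoring runs are reproducible given the seed. A condensed sketch of that sampling loop is below; it assumes `rng` is a `numpy.random.default_rng` generator seeded from the `held_out_workloads_seed` flag (the generator construction sits above the hunks shown), and only two workload families are included for brevity.

```python
import numpy as np

HELD_OUT_WORKLOADS = {
  'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'],
  'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'],
}


def sample_held_out(workloads: dict, seed: int) -> list:
  """Pick one variant per base workload, mirroring the loop in main()."""
  rng = np.random.default_rng(seed)  # assumed seeding; construction not shown
  sampled = []
  for _, variants in workloads.items():
    sampled.append(variants[rng.integers(len(variants))])
  return sampled


print(sample_held_out(HELD_OUT_WORKLOADS, seed=3438810845))
```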
diff --git a/scoring/algoperf_v05/score_submissions.py b/scoring/algoperf_v05/score_submissions.py
index 8cc06b15f..6ef931a54 100644
--- a/scoring/algoperf_v05/score_submissions.py
+++ b/scoring/algoperf_v05/score_submissions.py
@@ -16,57 +16,62 @@
import os
import pickle
-from absl import app
-from absl import flags
-from absl import logging
import numpy as np
import pandas as pd
import performance_profile
import scoring_utils
+from absl import app, flags, logging
from tabulate import tabulate
flags.DEFINE_string(
- 'submission_directory',
- None,
- 'Path to submission directory containing experiment directories.')
+ 'submission_directory',
+ None,
+ 'Path to submission directory containing experiment directories.',
+)
flags.DEFINE_string(
- 'output_dir',
- 'scoring_results',
- 'Path to save performance profile artifacts, submission_summaries and results files.'
+ 'output_dir',
+ 'scoring_results',
+ 'Path to save performance profile artifacts, submission_summaries and results files.',
+)
+flags.DEFINE_boolean(
+ 'compute_performance_profiles',
+ False,
+ 'Whether or not to compute the performance profiles.',
)
-flags.DEFINE_boolean('compute_performance_profiles',
- False,
- 'Whether or not to compute the performance profiles.')
flags.DEFINE_boolean(
- 'strict',
- False,
- 'Whether to enforce scoring criteria on variant performance and on'
- '5-trial median performance. Note that during official scoring this '
- 'flag will be set to True.')
+ 'strict',
+ False,
+ 'Whether to enforce scoring criteria on variant performance and on '
+ '5-trial median performance. Note that during official scoring this '
+ 'flag will be set to True.',
+)
flags.DEFINE_boolean(
- 'self_tuning_ruleset',
- False,
- 'Whether to score on self-tuning ruleset or externally tuned ruleset')
+ 'self_tuning_ruleset',
+ False,
+ 'Whether to score on self-tuning ruleset or externally tuned ruleset',
+)
flags.DEFINE_string(
- 'save_results_to_filename',
- None,
- 'Filename to save the processed results that are fed into the performance profile functions.'
+ 'save_results_to_filename',
+ None,
+ 'Filename to save the processed results that are fed into the performance profile functions.',
)
flags.DEFINE_string(
- 'load_results_from_filename',
- None,
- 'Filename to load processed results from that are fed into performance profile functions'
+ 'load_results_from_filename',
+ None,
+ 'Filename to load processed results from that are fed into the performance profile functions.',
)
flags.DEFINE_string(
- 'exclude_submissions',
- '',
- 'Optional comma seperated list of names of submissions to exclude from scoring.'
+ 'exclude_submissions',
+ '',
+ 'Optional comma separated list of names of submissions to exclude from scoring.',
)
FLAGS = flags.FLAGS
def get_summary_df(workload, workload_df, include_test_split=False):
- validation_metric, validation_target = scoring_utils.get_workload_metrics_and_targets(workload, split='validation')
+ validation_metric, validation_target = (
+ scoring_utils.get_workload_metrics_and_targets(workload, split='validation')
+ )
is_minimized = performance_profile.check_if_minimized(validation_metric)
target_op = operator.le if is_minimized else operator.ge
@@ -79,47 +84,69 @@ def get_summary_df(workload, workload_df, include_test_split=False):
summary_df['val target metric name'] = validation_metric
summary_df['val target metric value'] = validation_target
- summary_df['val target reached'] = workload_df[validation_metric].apply(
- lambda x: target_op(x, validation_target)).apply(np.any)
+ summary_df['val target reached'] = (
+ workload_df[validation_metric]
+ .apply(lambda x: target_op(x, validation_target))
+ .apply(np.any)
+ )
summary_df['best metric value on val'] = workload_df[validation_metric].apply(
- lambda x: best_op(x))
+ lambda x: best_op(x)
+ )
workload_df['index best eval on val'] = workload_df[validation_metric].apply(
- lambda x: idx_op(x))
+ lambda x: idx_op(x)
+ )
summary_df['time to best eval on val (s)'] = workload_df.apply(
- lambda x: x['accumulated_submission_time'][x['index best eval on val']],
- axis=1)
- workload_df['val target reached'] = workload_df[validation_metric].apply(
- lambda x: target_op(x, validation_target)).apply(np.any)
+ lambda x: x['accumulated_submission_time'][x['index best eval on val']],
+ axis=1,
+ )
+ workload_df['val target reached'] = (
+ workload_df[validation_metric]
+ .apply(lambda x: target_op(x, validation_target))
+ .apply(np.any)
+ )
workload_df['index to target on val'] = workload_df.apply(
- lambda x: np.argmax(target_op(x[validation_metric], validation_target))
- if x['val target reached'] else np.nan,
- axis=1)
+ lambda x: np.argmax(target_op(x[validation_metric], validation_target))
+ if x['val target reached']
+ else np.nan,
+ axis=1,
+ )
summary_df['time to target on val (s)'] = workload_df.apply(
- lambda x: x['accumulated_submission_time'][int(x[
- 'index to target on val'])] if x['val target reached'] else np.inf,
- axis=1)
+ lambda x: x['accumulated_submission_time'][int(x['index to target on val'])]
+ if x['val target reached']
+ else np.inf,
+ axis=1,
+ )
# test metrics
if include_test_split:
- test_metric, test_target = scoring_utils.get_workload_metrics_and_targets(workload, split='test')
+ test_metric, test_target = scoring_utils.get_workload_metrics_and_targets(
+ workload, split='test'
+ )
summary_df['test target metric name'] = test_metric
summary_df['test target metric value'] = test_target
- summary_df['test target reached'] = workload_df[test_metric].apply(
- lambda x: target_op(x, test_target)).apply(np.any)
+ summary_df['test target reached'] = (
+ workload_df[test_metric]
+ .apply(lambda x: target_op(x, test_target))
+ .apply(np.any)
+ )
summary_df['best metric value on test'] = workload_df[test_metric].apply(
- lambda x: best_op(x))
+ lambda x: best_op(x)
+ )
workload_df['index best eval on test'] = workload_df[test_metric].apply(
- lambda x: idx_op(x))
+ lambda x: idx_op(x)
+ )
summary_df['time to best eval on test (s)'] = workload_df.apply(
- lambda x: x['accumulated_submission_time'][x['index best eval on test']
- ],
- axis=1)
+ lambda x: x['accumulated_submission_time'][x['index best eval on test']],
+ axis=1,
+ )
summary_df['time to target on test (s)'] = summary_df.apply(
- lambda x: x['time to best eval on test (s)']
- if x['test target reached'] else np.inf,
- axis=1)
+ lambda x: x['time to best eval on test (s)']
+ if x['test target reached']
+ else np.inf,
+ axis=1,
+ )
return summary_df
@@ -133,7 +160,8 @@ def get_submission_summary(df, include_test_split=True):
print(df)
for workload, group in df.groupby('workload'):
summary_df = get_summary_df(
- workload, group, include_test_split=include_test_split)
+ workload, group, include_test_split=include_test_split
+ )
dfs.append(summary_df)
df = pd.concat(dfs)
@@ -164,61 +192,64 @@ def main(_):
# Optionally read results to filename
if FLAGS.load_results_from_filename:
with open(
- os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename),
- 'rb') as f:
+ os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename), 'rb'
+ ) as f:
results = pickle.load(f)
else:
for team in os.listdir(FLAGS.submission_directory):
for submission in os.listdir(
- os.path.join(FLAGS.submission_directory, team)):
+ os.path.join(FLAGS.submission_directory, team)
+ ):
print(submission)
if submission in FLAGS.exclude_submissions.split(','):
continue
- experiment_path = os.path.join(FLAGS.submission_directory,
- team,
- submission)
+ experiment_path = os.path.join(
+ FLAGS.submission_directory, team, submission
+ )
df = scoring_utils.get_experiment_df(experiment_path)
results[submission] = df
summary_df = get_submission_summary(df)
with open(
- os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'),
- 'w') as fout:
+ os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), 'w'
+ ) as fout:
summary_df.to_csv(fout)
# Optionally save results to filename
if FLAGS.save_results_to_filename:
with open(
- os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename),
- 'wb') as f:
+ os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename), 'wb'
+ ) as f:
pickle.dump(results, f)
if not FLAGS.strict:
logging.warning(
- 'You are running with strict=False. This will relax '
- 'scoring criteria on the held-out workloads, number of trials and number '
- 'of studies. Your score may not be an accurate representation '
- 'under competition scoring rules. To enforce the criteria set strict=True.'
+ 'You are running with strict=False. This will relax '
+ 'scoring criteria on the held-out workloads, number of trials and number '
+ 'of studies. Your score may not be an accurate representation '
+ 'under competition scoring rules. To enforce the criteria set strict=True.'
)
if FLAGS.compute_performance_profiles:
performance_profile_df = performance_profile.compute_performance_profiles(
- results,
- time_col='score',
- min_tau=1.0,
- max_tau=4.0,
- reference_submission_tag=None,
- num_points=100,
- scale='linear',
- verbosity=0,
- self_tuning_ruleset=FLAGS.self_tuning_ruleset,
- strict=FLAGS.strict,
- output_dir=FLAGS.output_dir,
+ results,
+ time_col='score',
+ min_tau=1.0,
+ max_tau=4.0,
+ reference_submission_tag=None,
+ num_points=100,
+ scale='linear',
+ verbosity=0,
+ self_tuning_ruleset=FLAGS.self_tuning_ruleset,
+ strict=FLAGS.strict,
+ output_dir=FLAGS.output_dir,
)
if not os.path.exists(FLAGS.output_dir):
os.mkdir(FLAGS.output_dir)
performance_profile.plot_performance_profiles(
- performance_profile_df, 'score', save_dir=FLAGS.output_dir)
+ performance_profile_df, 'score', save_dir=FLAGS.output_dir
+ )
performance_profile_str = tabulate(
- performance_profile_df.T, headers='keys', tablefmt='psql')
+ performance_profile_df.T, headers='keys', tablefmt='psql'
+ )
logging.info(f'Performance profile:\n {performance_profile_str}')
scores = compute_leaderboard_score(performance_profile_df)
scores.to_csv(os.path.join(FLAGS.output_dir, 'scores.csv'))
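The summary-building code above decides, per trial, whether and when the validation target was reached: `np.argmax(target_op(metric_series, target))` yields the first eval index where the comparison holds, and that index selects the entry of `accumulated_submission_time`. A self-contained sketch of that time-to-target logic with the DataFrame plumbing stripped away (the metric values and times are invented):

```python
import operator

import numpy as np


def time_to_target(metric_per_eval, times_per_eval, target, minimize=True):
  """First accumulated submission time at which the target is hit, else inf."""
  op = operator.le if minimize else operator.ge
  reached = op(np.asarray(metric_per_eval), target)
  if not reached.any():
    return np.inf
  first_idx = int(np.argmax(reached))  # argmax returns the first True
  return times_per_eval[first_idx]


# Example: a WER-style metric (lower is better) with target 0.12.
print(time_to_target([0.30, 0.15, 0.11, 0.10], [600, 1200, 1800, 2400], 0.12))
# -> 1800
```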
diff --git a/scoring/compute_speedups.py b/scoring/compute_speedups.py
index d0e5bf70b..1740a6dce 100644
--- a/scoring/compute_speedups.py
+++ b/scoring/compute_speedups.py
@@ -2,39 +2,39 @@
import pickle
-from absl import app
-from absl import flags
import numpy as np
import pandas as pd
-from performance_profile import BASE_WORKLOADS
-from performance_profile import get_workloads_time_to_target
+from absl import app, flags
+from performance_profile import BASE_WORKLOADS, get_workloads_time_to_target
from scipy import stats
flags.DEFINE_string('results_txt', None, 'Path to full scoring results file.')
flags.DEFINE_string(
- 'base',
- 'prize_qualification_baseline',
- 'Base submission to compare to. Defaults to the `prize_qualification_baseline`.'
+ 'base',
+ 'prize_qualification_baseline',
+ 'Base submission to compare to. Defaults to the `prize_qualification_baseline`.',
)
flags.DEFINE_string('comparison', None, 'Submission to compute the speedup of.')
-flags.DEFINE_boolean('self_tuning_ruleset',
- False,
- 'Whether the self-tuning ruleset is being scored.')
-flags.DEFINE_boolean('save_results',
- False,
- 'Whether to save the results to disk.')
+flags.DEFINE_boolean(
+ 'self_tuning_ruleset',
+ False,
+ 'Whether the self-tuning ruleset is being scored.',
+)
+flags.DEFINE_boolean(
+ 'save_results', False, 'Whether to save the results to disk.'
+)
FLAGS = flags.FLAGS
# These are the old budgets, used in the first iteration of the competition.
MAX_BUDGETS = {
- 'criteo1tb': 7703,
- 'fastmri': 8859,
- 'imagenet_resnet': 63_008,
- 'imagenet_vit': 77_520,
- 'librispeech_conformer': 61_068,
- 'librispeech_deepspeech': 55_506,
- 'ogbg': 18_477,
- 'wmt': 48_151,
+ 'criteo1tb': 7703,
+ 'fastmri': 8859,
+ 'imagenet_resnet': 63_008,
+ 'imagenet_vit': 77_520,
+ 'librispeech_conformer': 61_068,
+ 'librispeech_deepspeech': 55_506,
+ 'ogbg': 18_477,
+ 'wmt': 48_151,
}
@@ -63,16 +63,16 @@ def compute_speedup():
# Compute median over runtimes for both training algorithms
base_results = get_workloads_time_to_target(
- results[FLAGS.base],
- FLAGS.base,
- time_col="score",
- self_tuning_ruleset=FLAGS.self_tuning_ruleset,
+ results[FLAGS.base],
+ FLAGS.base,
+ time_col='score',
+ self_tuning_ruleset=FLAGS.self_tuning_ruleset,
)
comparison_results = get_workloads_time_to_target(
- results[FLAGS.comparison],
- FLAGS.comparison,
- time_col="score",
- self_tuning_ruleset=FLAGS.self_tuning_ruleset,
+ results[FLAGS.comparison],
+ FLAGS.comparison,
+ time_col='score',
+ self_tuning_ruleset=FLAGS.self_tuning_ruleset,
)
# Merge results
@@ -85,20 +85,23 @@ def compute_speedup():
merged_results = merged_results.apply(replace_inf, axis=1)
# Compute speedup
- merged_results['speedup'] = merged_results[
- f'{FLAGS.comparison}'] / merged_results[f'{FLAGS.base}']
+ merged_results['speedup'] = (
+ merged_results[f'{FLAGS.comparison}'] / merged_results[f'{FLAGS.base}']
+ )
speedups = merged_results['speedup'].to_numpy()
mean_speedup = stats.gmean(speedups) # Geometric mean over workload speedups
print(merged_results, end='\n\n')
print(
- f"Average speedup of {FLAGS.comparison} compared to {FLAGS.base}: {mean_speedup} or roughly {(1-mean_speedup):.1%}"
+ f'Average speedup of {FLAGS.comparison} compared to {FLAGS.base}: {mean_speedup} or roughly {(1 - mean_speedup):.1%}'
)
if FLAGS.save_results:
# Optionally save results to disk
- print("Saving results to disk...")
- filename = f'{FLAGS.comparison}_vs_{FLAGS.base}_speedup_{(1-mean_speedup):.1%}.csv'
+ print('Saving results to disk...')
+ filename = (
+ f'{FLAGS.comparison}_vs_{FLAGS.base}_speedup_{(1 - mean_speedup):.1%}.csv'
+ )
merged_results.to_csv(filename)
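`compute_speedup` divides the comparison submission's per-workload time-to-target by the base submission's and aggregates the ratios with a geometric mean. A minimal numeric sketch of that aggregation, with made-up per-workload times:

```python
import pandas as pd
from scipy import stats

# Hypothetical per-workload times-to-target in seconds.
base = pd.Series({'ogbg': 10_000.0, 'wmt': 40_000.0, 'fastmri': 8_000.0})
comparison = pd.Series({'ogbg': 8_000.0, 'wmt': 30_000.0, 'fastmri': 7_000.0})

# Per-workload ratio: < 1 means the comparison submission was faster.
speedup = comparison / base

# Aggregate with the geometric mean over workloads, as in compute_speedup().
mean_speedup = stats.gmean(speedup.to_numpy())
print(f'geometric mean ratio: {mean_speedup:.3f} '
      f'(~{(1 - mean_speedup):.1%} less training time)')
```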
diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py
index 05026a0c7..4f2ae9c57 100644
--- a/scoring/performance_profile.py
+++ b/scoring/performance_profile.py
@@ -25,21 +25,22 @@
The keys in this dictionary should match the workload identifiers used in
the dictionary of submissions.
"""
+
import itertools
import json
import operator
import os
import re
-from absl import logging
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
+from absl import logging
from tabulate import tabulate
-from algoperf.workloads.workloads import get_base_workload_name
import algoperf.workloads.workloads as workloads_registry
+from algoperf.workloads.workloads import get_base_workload_name
from scoring import scoring_utils
WORKLOADS = workloads_registry.WORKLOADS
@@ -49,9 +50,9 @@
# Open json file to read heldout workloads
# TODO: This probably shouldn't be hardcoded but passed as an argument.\
try:
- with open("held_out_workloads_algoperf_v05.json", "r") as f:
+ with open('held_out_workloads_algoperf_v05.json', 'r') as f:
HELDOUT_WORKLOADS = json.load(f)
-except:
+except FileNotFoundError:
HELDOUT_WORKLOADS = None
# These global variables have to be set according to the current set of
@@ -64,22 +65,22 @@
NUM_STUDIES = 3
MIN_EVAL_METRICS = [
- 'ce_loss',
- 'error_rate',
- 'ctc_loss',
- 'wer',
- 'l1_loss',
- 'loss',
+ 'ce_loss',
+ 'error_rate',
+ 'ctc_loss',
+ 'wer',
+ 'l1_loss',
+ 'loss',
]
MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu']
-#MPL params
+# MPL params
mpl.rcParams['figure.figsize'] = (16, 10) # Width, height in inches
mpl.rcParams['font.family'] = 'serif'
-mpl.rcParams['font.serif'] = [
- 'Times New Roman'
-] + mpl.rcParams['font.serif'] # Add Times New Roman as first choice
+mpl.rcParams['font.serif'] = ['Times New Roman'] + mpl.rcParams[
+ 'font.serif'
+] # Add Times New Roman as first choice
mpl.rcParams['font.size'] = 22
mpl.rcParams['savefig.dpi'] = 300 # Set resolution for saved figures
@@ -87,16 +88,17 @@
mpl.rcParams['lines.linewidth'] = 3 # Adjust line thickness if needed
mpl.rcParams['lines.markersize'] = 6 # Adjust marker size if needed
mpl.rcParams['axes.prop_cycle'] = mpl.cycler(
- color=["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728",
- "#9467bd"]) # Example color cycle (consider ColorBrewer or viridis)
+ color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']
+) # Example color cycle (consider ColorBrewer or viridis)
mpl.rcParams['axes.labelsize'] = 22 # Axis label font size
mpl.rcParams['xtick.labelsize'] = 20 # Tick label font size
mpl.rcParams['ytick.labelsize'] = 20
# Legends and Gridlines
mpl.rcParams['legend.fontsize'] = 20 # Legend font size
-mpl.rcParams[
- 'legend.loc'] = 'best' # Let matplotlib decide the best legend location
+mpl.rcParams['legend.loc'] = (
+ 'best' # Let matplotlib decide the best legend location
+)
mpl.rcParams['axes.grid'] = True # Enable grid
mpl.rcParams['grid.alpha'] = 0.4 # Gridline transparency
@@ -113,7 +115,8 @@ def generate_eval_cols(metrics):
MINIMIZE_REGISTRY = {k: True for k in generate_eval_cols(MIN_EVAL_METRICS)}
MINIMIZE_REGISTRY.update(
- {k: False for k in generate_eval_cols(MAX_EVAL_METRICS)})
+ {k: False for k in generate_eval_cols(MAX_EVAL_METRICS)}
+)
MINIMIZE_REGISTRY['train_cost'] = True
@@ -125,13 +128,15 @@ def check_if_minimized(col_name):
if col in col_name:
return MINIMIZE_REGISTRY[col]
- raise ValueError(f'Column {col_name} not found in `MINIMIZE_REGISTRY` as '
- 'either a column name or a substring of a column name.')
+ raise ValueError(
+ f'Column {col_name} not found in `MINIMIZE_REGISTRY` as '
+ 'either a column name or a substring of a column name.'
+ )
-def get_best_trial_index(workload_df,
- validation_metric,
- validation_target=None):
+def get_best_trial_index(
+ workload_df, validation_metric, validation_target=None
+):
"""Get the eval index in which a workload reaches the target metric_col.
Args:
@@ -150,7 +155,8 @@ def get_best_trial_index(workload_df,
op = operator.le if is_minimized else operator.ge
validation_target_reached = validation_series.apply(
- lambda x: op(x, validation_target))
+ lambda x: op(x, validation_target)
+ )
target_reached = pd.Series(validation_target_reached)
# Remove trials that never reach the target
@@ -166,12 +172,14 @@ def get_best_trial_index(workload_df,
return trial, index_reached[trial]
-def get_workloads_time_to_target(submission,
- submission_name,
- time_col='global_step',
- verbosity=1,
- self_tuning_ruleset=False,
- strict=False):
+def get_workloads_time_to_target(
+ submission,
+ submission_name,
+ time_col='global_step',
+ verbosity=1,
+ self_tuning_ruleset=False,
+ strict=False,
+):
"""Get times to target for each workload in a submission.
Args:
@@ -191,60 +199,72 @@ def get_workloads_time_to_target(submission,
if num_workloads != NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS:
if strict:
raise ValueError(
- f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads '
- f'but found {num_workloads} workloads for {submission_name}.')
- logging.warning(
f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads '
- f'but found {num_workloads} workloads for {submission_name}.')
+ f'but found {num_workloads} workloads for {submission_name}.'
+ )
+ logging.warning(
+ f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads '
+ f'but found {num_workloads} workloads for {submission_name}.'
+ )
# For each workload get submission time get the submission times to target.
for workload, group in submission.groupby('workload'):
- validation_metric, validation_target = scoring_utils.get_workload_metrics_and_targets(workload)
+ validation_metric, validation_target = (
+ scoring_utils.get_workload_metrics_and_targets(workload)
+ )
# Check number of studies
time_vals_per_study = []
num_studies = len(group.groupby('study'))
if num_studies != NUM_STUDIES:
if strict:
- raise ValueError(f'Expecting {NUM_STUDIES} studies for workload '
- f'{workload} but found {num_studies} studies '
- f'for {submission_name}.')
+ raise ValueError(
+ f'Expecting {NUM_STUDIES} studies for workload '
+ f'{workload} but found {num_studies} studies '
+ f'for {submission_name}.'
+ )
else:
- logging.warning(f'Expecting {NUM_STUDIES} studies for workload '
- f'{workload} but found {num_studies} studies '
- f'for {submission_name}.')
+ logging.warning(
+ f'Expecting {NUM_STUDIES} studies for workload '
+ f'{workload} but found {num_studies} studies '
+ f'for {submission_name}.'
+ )
# For each study check trials
for study, group in group.groupby('study'):
-
# Check number of trials per study
num_trials = len(group)
if num_trials != NUM_TRIALS and not self_tuning_ruleset:
if strict:
raise ValueError(
- f'In Study {study}: Expecting {NUM_TRIALS} trials for workload '
- f'{workload} but found {num_trials} trials '
- f'for {submission_name}.')
+ f'In Study {study}: Expecting {NUM_TRIALS} trials for workload '
+ f'{workload} but found {num_trials} trials '
+ f'for {submission_name}.'
+ )
else:
logging.warning(
- f'In Study {study}: Expecting {NUM_TRIALS} trials for workload '
- f'{workload} but found {num_trials} trials '
- f'for {submission_name}.')
+ f'In Study {study}: Expecting {NUM_TRIALS} trials for workload '
+ f'{workload} but found {num_trials} trials '
+ f'for {submission_name}.'
+ )
# Get trial and time index that reaches target
trial_idx, time_idx = get_best_trial_index(
- group, validation_metric, validation_target)
+ group, validation_metric, validation_target
+ )
if time_idx > -1:
time_val = group[time_col].loc[trial_idx][time_idx]
else:
time_val = float('inf')
time_vals_per_study.append(time_val)
- workloads.append({
+ workloads.append(
+ {
'submission': submission_name,
'workload': re.sub(r'_(jax|pytorch)$', '', workload),
time_col: np.median(time_vals_per_study),
- })
+ }
+ )
df = pd.DataFrame.from_records(workloads)
df = df.pivot(index='submission', columns='workload', values=time_col)
@@ -252,7 +272,6 @@ def get_workloads_time_to_target(submission,
def variant_criteria_filter(base_workload, variant_workload):
-
def filter(x):
try:
if x[variant_workload] == np.inf:
@@ -269,17 +288,19 @@ def filter(x):
return filter
-def compute_performance_profiles(submissions,
- time_col='global_step',
- min_tau=1.0,
- max_tau=None,
- reference_submission_tag=None,
- num_points=100,
- scale='linear',
- verbosity=0,
- strict=False,
- self_tuning_ruleset=False,
- output_dir=None):
+def compute_performance_profiles(
+ submissions,
+ time_col='global_step',
+ min_tau=1.0,
+ max_tau=None,
+ reference_submission_tag=None,
+ num_points=100,
+ scale='linear',
+ verbosity=0,
+ strict=False,
+ self_tuning_ruleset=False,
+ output_dir=None,
+):
"""Compute performance profiles for a set of submission by some time column.
Args:
@@ -308,16 +329,20 @@ def compute_performance_profiles(submissions,
for submission_tag, submission in submissions.items():
logging.info(
- f'\nComputing performance profile with respect to `{time_col}` for '
- f'{submission_tag}')
+ f'\nComputing performance profile with respect to `{time_col}` for '
+ f'{submission_tag}'
+ )
# Get time to targets for each submission across studies and trials
dfs.append(
- get_workloads_time_to_target(submission,
- submission_tag,
- time_col,
- verbosity,
- self_tuning_ruleset,
- strict))
+ get_workloads_time_to_target(
+ submission,
+ submission_tag,
+ time_col,
+ verbosity,
+ self_tuning_ruleset,
+ strict,
+ )
+ )
df = pd.concat(dfs)
# Restrict to base and sampled held-out workloads
# (ignore the additional workload variants of the baseline
@@ -335,7 +360,8 @@ def compute_performance_profiles(submissions,
# If base do not have finite score set variant score to inf
base_workload = get_base_workload_name(workload)
df[workload] = df.apply(
- variant_criteria_filter(workload, base_workload), axis=1)
+ variant_criteria_filter(workload, base_workload), axis=1
+ )
# Set score to inf if not within 4x of fastest submission
best_scores = df.min(axis=0)
@@ -347,17 +373,20 @@ def compute_performance_profiles(submissions,
# If variants do not have finite score set base_workload score to inf
base_workload = get_base_workload_name(workload)
df[base_workload] = df.apply(
- variant_criteria_filter(base_workload, workload), axis=1)
+ variant_criteria_filter(base_workload, workload), axis=1
+ )
df = df[BASE_WORKLOADS]
if verbosity > 0:
logging.info('\n`{time_col}` to reach target:')
- with pd.option_context('display.max_rows',
- None,
- 'display.max_columns',
- None,
- 'display.width',
- 1000):
+ with pd.option_context(
+ 'display.max_rows',
+ None,
+ 'display.max_columns',
+ None,
+ 'display.width',
+ 1000,
+ ):
logging.info(df)
# Divide by the fastest.
@@ -368,12 +397,14 @@ def compute_performance_profiles(submissions,
if verbosity > 0:
logging.info('\n`{time_col}` to reach target normalized to best:')
- with pd.option_context('display.max_rows',
- None,
- 'display.max_columns',
- None,
- 'display.width',
- 1000):
+ with pd.option_context(
+ 'display.max_rows',
+ None,
+ 'display.max_columns',
+ None,
+ 'display.width',
+ 1000,
+ ):
logging.info(df)
# If no max_tau is supplied, choose the value of tau that would plot all non
@@ -385,7 +416,8 @@ def compute_performance_profiles(submissions,
points = np.linspace(min_tau, max_tau, num=num_points)
elif scale == 'log':
points = np.logspace(
- np.log10(min_tau), np.log10(max_tau), num=num_points, base=10.0)
+ np.log10(min_tau), np.log10(max_tau), num=num_points, base=10.0
+ )
def rho(r, tau):
return (r <= tau).sum(axis=1) / NUM_BASE_WORKLOADS
@@ -431,11 +463,9 @@ def maybe_save_df_to_csv(save_dir, df, path, **to_csv_kwargs):
df.to_csv(fout, **to_csv_kwargs)
-def plot_performance_profiles(perf_df,
- df_col,
- scale='linear',
- save_dir=None,
- figsize=(30, 10)):
+def plot_performance_profiles(
+ perf_df, df_col, scale='linear', save_dir=None, figsize=(30, 10)
+):
"""Plot performance profiles.
Args:
@@ -462,6 +492,6 @@ def plot_performance_profiles(perf_df,
fig.legend(bbox_to_anchor=(1.0, 1.0))
plt.tight_layout()
maybe_save_figure(save_dir, f'performance_profile_by_{df_col_display}')
- maybe_save_df_to_csv(save_dir,
- perf_df,
- f'performance_profile_{df_col_display}.csv')
+ maybe_save_df_to_csv(
+ save_dir, perf_df, f'performance_profile_{df_col_display}.csv'
+ )
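The performance-profile computation normalizes each submission's per-workload time-to-target by the fastest submission on that workload, then, for each tau, `rho` reports the fraction of base workloads whose ratio is at most tau. The "divide by the fastest" step itself sits just outside the hunks shown, so the `div` call below is an assumption about its form; the toy DataFrame is invented for illustration.

```python
import numpy as np
import pandas as pd

# Toy times-to-target (rows: submissions, columns: base workloads).
df = pd.DataFrame(
  {'ogbg': [100.0, 150.0], 'wmt': [400.0, 300.0], 'fastmri': [90.0, np.inf]},
  index=['algo_a', 'algo_b'],
)

# Normalize each workload by the fastest submission on that workload.
ratios = df.div(df.min(axis=0), axis=1)

# rho(tau): fraction of workloads solved within a factor tau of the fastest.
num_workloads = df.shape[1]
taus = np.linspace(1.0, 4.0, num=7)
profile = pd.DataFrame(
  {tau: (ratios <= tau).sum(axis=1) / num_workloads for tau in taus}
)
print(profile)
```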
diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py
index f07dc8cdd..b48509f02 100644
--- a/scoring/score_submissions.py
+++ b/scoring/score_submissions.py
@@ -17,57 +17,62 @@
import os
import pickle
-from absl import app
-from absl import flags
-from absl import logging
import numpy as np
import pandas as pd
import performance_profile
import scoring_utils
+from absl import app, flags, logging
from tabulate import tabulate
flags.DEFINE_string(
- 'submission_directory',
- None,
- 'Path to submission directory containing experiment directories.')
+ 'submission_directory',
+ None,
+ 'Path to submission directory containing experiment directories.',
+)
flags.DEFINE_string(
- 'output_dir',
- 'scoring_results',
- 'Path to save performance profile artifacts, submission_summaries and results files.'
+ 'output_dir',
+ 'scoring_results',
+ 'Path to save performance profile artifacts, submission_summaries and results files.',
+)
+flags.DEFINE_boolean(
+ 'compute_performance_profiles',
+ False,
+ 'Whether or not to compute the performance profiles.',
)
-flags.DEFINE_boolean('compute_performance_profiles',
- False,
- 'Whether or not to compute the performance profiles.')
flags.DEFINE_boolean(
- 'strict',
- False,
- 'Whether to enforce scoring criteria on variant performance and on'
- '5-trial median performance. Note that during official scoring this '
- 'flag will be set to True.')
+ 'strict',
+ False,
+ 'Whether to enforce scoring criteria on variant performance and on '
+ '5-trial median performance. Note that during official scoring this '
+ 'flag will be set to True.',
+)
flags.DEFINE_boolean(
- 'self_tuning_ruleset',
- False,
- 'Whether to score on self-tuning ruleset or externally tuned ruleset')
+ 'self_tuning_ruleset',
+ False,
+ 'Whether to score on self-tuning ruleset or externally tuned ruleset',
+)
flags.DEFINE_string(
- 'save_results_to_filename',
- None,
- 'Filename to save the processed results that are fed into the performance profile functions.'
+ 'save_results_to_filename',
+ None,
+ 'Filename to save the processed results that are fed into the performance profile functions.',
)
flags.DEFINE_string(
- 'load_results_from_filename',
- None,
- 'Filename to load processed results from that are fed into performance profile functions'
+ 'load_results_from_filename',
+ None,
+ 'Filename to load the processed results from, which are fed into the performance profile functions.',
)
flags.DEFINE_string(
- 'exclude_submissions',
- '',
- 'Optional comma seperated list of names of submissions to exclude from scoring.'
+ 'exclude_submissions',
+ '',
+ 'Optional comma separated list of names of submissions to exclude from scoring.',
)
FLAGS = flags.FLAGS
def get_summary_df(workload, workload_df, include_test_split=False):
- validation_metric, validation_target = scoring_utils.get_workload_metrics_and_targets(workload, split='validation')
+ validation_metric, validation_target = (
+ scoring_utils.get_workload_metrics_and_targets(workload, split='validation')
+ )
is_minimized = performance_profile.check_if_minimized(validation_metric)
target_op = operator.le if is_minimized else operator.ge
@@ -80,47 +85,69 @@ def get_summary_df(workload, workload_df, include_test_split=False):
summary_df['val target metric name'] = validation_metric
summary_df['val target metric value'] = validation_target
- summary_df['val target reached'] = workload_df[validation_metric].apply(
- lambda x: target_op(x, validation_target)).apply(np.any)
+ summary_df['val target reached'] = (
+ workload_df[validation_metric]
+ .apply(lambda x: target_op(x, validation_target))
+ .apply(np.any)
+ )
summary_df['best metric value on val'] = workload_df[validation_metric].apply(
- lambda x: best_op(x))
+ lambda x: best_op(x)
+ )
workload_df['index best eval on val'] = workload_df[validation_metric].apply(
- lambda x: idx_op(x))
+ lambda x: idx_op(x)
+ )
summary_df['time to best eval on val (s)'] = workload_df.apply(
- lambda x: x['accumulated_submission_time'][x['index best eval on val']],
- axis=1)
- workload_df['val target reached'] = workload_df[validation_metric].apply(
- lambda x: target_op(x, validation_target)).apply(np.any)
+ lambda x: x['accumulated_submission_time'][x['index best eval on val']],
+ axis=1,
+ )
+ workload_df['val target reached'] = (
+ workload_df[validation_metric]
+ .apply(lambda x: target_op(x, validation_target))
+ .apply(np.any)
+ )
workload_df['index to target on val'] = workload_df.apply(
- lambda x: np.argmax(target_op(x[validation_metric], validation_target))
- if x['val target reached'] else np.nan,
- axis=1)
+ lambda x: np.argmax(target_op(x[validation_metric], validation_target))
+ if x['val target reached']
+ else np.nan,
+ axis=1,
+ )
summary_df['time to target on val (s)'] = workload_df.apply(
- lambda x: x['accumulated_submission_time'][int(x[
- 'index to target on val'])] if x['val target reached'] else np.inf,
- axis=1)
+ lambda x: x['accumulated_submission_time'][int(x['index to target on val'])]
+ if x['val target reached']
+ else np.inf,
+ axis=1,
+ )
# test metrics
if include_test_split:
- test_metric, test_target = scoring_utils.get_workload_metrics_and_targets(workload, split='test')
+ test_metric, test_target = scoring_utils.get_workload_metrics_and_targets(
+ workload, split='test'
+ )
summary_df['test target metric name'] = test_metric
summary_df['test target metric value'] = test_target
- summary_df['test target reached'] = workload_df[test_metric].apply(
- lambda x: target_op(x, test_target)).apply(np.any)
+ summary_df['test target reached'] = (
+ workload_df[test_metric]
+ .apply(lambda x: target_op(x, test_target))
+ .apply(np.any)
+ )
summary_df['best metric value on test'] = workload_df[test_metric].apply(
- lambda x: best_op(x))
+ lambda x: best_op(x)
+ )
workload_df['index best eval on test'] = workload_df[test_metric].apply(
- lambda x: idx_op(x))
+ lambda x: idx_op(x)
+ )
summary_df['time to best eval on test (s)'] = workload_df.apply(
- lambda x: x['accumulated_submission_time'][x['index best eval on test']
- ],
- axis=1)
+ lambda x: x['accumulated_submission_time'][x['index best eval on test']],
+ axis=1,
+ )
summary_df['time to target on test (s)'] = summary_df.apply(
- lambda x: x['time to best eval on test (s)']
- if x['test target reached'] else np.inf,
- axis=1)
+ lambda x: x['time to best eval on test (s)']
+ if x['test target reached']
+ else np.inf,
+ axis=1,
+ )
return summary_df
@@ -134,7 +161,8 @@ def get_submission_summary(df, include_test_split=True):
print(df)
for workload, group in df.groupby('workload'):
summary_df = get_summary_df(
- workload, group, include_test_split=include_test_split)
+ workload, group, include_test_split=include_test_split
+ )
dfs.append(summary_df)
df = pd.concat(dfs)
@@ -161,13 +189,13 @@ def compute_leaderboard_score(df, normalize=True):
def main(_):
results = {}
os.makedirs(FLAGS.output_dir, exist_ok=True)
- logging.info(f"Scoring submissions in {FLAGS.submission_directory}")
+ logging.info(f'Scoring submissions in {FLAGS.submission_directory}')
# Optionally read results to filename
if FLAGS.load_results_from_filename:
with open(
- os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename),
- 'rb') as f:
+ os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename), 'rb'
+ ) as f:
results = pickle.load(f)
else:
for submission in os.listdir(FLAGS.submission_directory):
@@ -179,44 +207,46 @@ def main(_):
results[submission] = df
summary_df = get_submission_summary(df)
with open(
- os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'),
- 'w') as fout:
+ os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), 'w'
+ ) as fout:
summary_df.to_csv(fout)
# Optionally save results to filename
if FLAGS.save_results_to_filename:
with open(
- os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename),
- 'wb') as f:
+ os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename), 'wb'
+ ) as f:
pickle.dump(results, f)
if not FLAGS.strict:
logging.warning(
- 'You are running with strict=False. This will relax '
- 'scoring criteria on the held-out workloads, number of trials and number '
- 'of studies. Your score may not be an accurate representation '
- 'under competition scoring rules. To enforce the criteria set strict=True.'
+ 'You are running with strict=False. This will relax '
+ 'scoring criteria on the held-out workloads, number of trials and number '
+ 'of studies. Your score may not be an accurate representation '
+ 'under competition scoring rules. To enforce the criteria set strict=True.'
)
if FLAGS.compute_performance_profiles:
performance_profile_df = performance_profile.compute_performance_profiles(
- results,
- time_col='score',
- min_tau=1.0,
- max_tau=4.0,
- reference_submission_tag=None,
- num_points=100,
- scale='linear',
- verbosity=0,
- self_tuning_ruleset=FLAGS.self_tuning_ruleset,
- strict=FLAGS.strict,
- output_dir=FLAGS.output_dir,
+ results,
+ time_col='score',
+ min_tau=1.0,
+ max_tau=4.0,
+ reference_submission_tag=None,
+ num_points=100,
+ scale='linear',
+ verbosity=0,
+ self_tuning_ruleset=FLAGS.self_tuning_ruleset,
+ strict=FLAGS.strict,
+ output_dir=FLAGS.output_dir,
)
if not os.path.exists(FLAGS.output_dir):
os.mkdir(FLAGS.output_dir)
performance_profile.plot_performance_profiles(
- performance_profile_df, 'score', save_dir=FLAGS.output_dir)
+ performance_profile_df, 'score', save_dir=FLAGS.output_dir
+ )
performance_profile_str = tabulate(
- performance_profile_df.T, headers='keys', tablefmt='psql')
+ performance_profile_df.T, headers='keys', tablefmt='psql'
+ )
logging.info(f'Performance profile:\n {performance_profile_str}')
scores = compute_leaderboard_score(performance_profile_df)
scores.to_csv(os.path.join(FLAGS.output_dir, 'scores.csv'))
diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py
index ac513816e..5be6c790c 100644
--- a/scoring/scoring_utils.py
+++ b/scoring/scoring_utils.py
@@ -4,8 +4,8 @@
import os
import re
-from absl import logging
import pandas as pd
+from absl import logging
import algoperf.workloads.workloads as workloads_registry
@@ -13,7 +13,7 @@
METRICS_LINE_REGEX = '(.*) Metrics: ({.*})'
TRIAL_DIR_REGEX = 'trial_(\d+)'
MEASUREMENTS_FILENAME = 'eval_measurements.csv'
-TIMESTAMP = r"-\d{4}(-\d{2}){5}"
+TIMESTAMP = r'-\d{4}(-\d{2}){5}'
WORKLOADS = workloads_registry.WORKLOADS
WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)'
@@ -22,12 +22,11 @@
#### File IO helper functions ###
def get_logfile_paths(logdir):
- """Gets all files ending in .log in logdir
- """
+ """Gets all files ending in .log in logdir"""
filenames = os.listdir(logdir)
logfile_paths = []
for f in filenames:
- if f.endswith(".log"):
+ if f.endswith('.log'):
f = os.path.join(logdir, f)
logfile_paths.append(f)
return logfile_paths
@@ -36,23 +35,23 @@ def get_logfile_paths(logdir):
### Logfile reading helper functions ###
def decode_metrics_line(line):
"""Convert metrics line to dict.
- Args:
- line: str
-
- Returns:
- dict_of_lists: dict where keys are metric names and vals
- are lists of values.
- e.g. {'loss':[5.1, 3.2, 1.0],
- 'step':[100, 200, 300]}
- """
+ Args:
+ line: str
+
+ Returns:
+ dict_of_lists: dict where keys are metric names and vals
+ are lists of values.
+ e.g. {'loss':[5.1, 3.2, 1.0],
+ 'step':[100, 200, 300]}
+ """
eval_results = []
dict_str = re.match(METRICS_LINE_REGEX, line).group(2)
- dict_str = dict_str.replace("'", "\"")
- dict_str = dict_str.replace("(", "")
- dict_str = dict_str.replace(")", "")
- dict_str = dict_str.replace("DeviceArray", "")
- dict_str = dict_str.replace(", dtype=float32", "")
- dict_str = dict_str.replace("nan", "0")
+ dict_str = dict_str.replace("'", '"')
+ dict_str = dict_str.replace('(', '')
+ dict_str = dict_str.replace(')', '')
+ dict_str = dict_str.replace('DeviceArray', '')
+ dict_str = dict_str.replace(', dtype=float32', '')
+ dict_str = dict_str.replace('nan', '0')
metrics_dict = json.loads(dict_str)
for item in metrics_dict['eval_results']:
if isinstance(item, dict):
@@ -73,18 +72,18 @@ def decode_metrics_line(line):
def get_trials_dict(logfile):
- """Get a dict of dicts with metrics for each
- tuning run.
-
- Returns:
- trials_dict: Dict of dicts where outer dict keys
- are trial indices and inner dict key-value pairs
- are metrics and list of values.
- e.g. {'trial_0': {'loss':[5.1, 3.2, 1.0],
- 'step':[100, 200, 300]},
- 'trial_1': {'loss':[5.1, 3.2, 1.0],
- 'step':[100, 200, 300]}}
- """
+ """Get a dict of dicts with metrics for each
+ tuning run.
+
+ Returns:
+ trials_dict: Dict of dicts where outer dict keys
+ are trial indices and inner dict key-value pairs
+ are metrics and list of values.
+ e.g. {'trial_0': {'loss':[5.1, 3.2, 1.0],
+ 'step':[100, 200, 300]},
+ 'trial_1': {'loss':[5.1, 3.2, 1.0],
+ 'step':[100, 200, 300]}}
+ """
trial = 0
metrics_lines = {}
with open(logfile, 'r') as f:
@@ -100,16 +99,16 @@ def get_trials_dict(logfile):
### Results formatting helper functions ###
def get_trials_df_dict(logfile):
- """Get a dict with dataframes with metrics for each
- tuning run.
- Preferable format for saving dataframes for tables.
- Args:
- logfile: str path to logfile.
-
- Returns:
- DataFrame where indices are index of eval and
- columns are metric names.
- """
+ """Get a dict with dataframes with metrics for each
+ tuning run.
+ Preferable format for saving dataframes for tables.
+ Args:
+ logfile: str path to logfile.
+
+ Returns:
+ DataFrame where indices are index of eval and
+ columns are metric names.
+ """
trials_dict = get_trials_dict(logfile)
trials_df_dict = {}
for trial, metrics in trials_dict.items():
@@ -119,20 +118,20 @@ def get_trials_df_dict(logfile):
def get_trials_df(logfile):
"""Gets a df of per trial results from a logfile.
- Args:
- experiment_dir: str
-
- Returns:
- df: DataFrame where indices are trials, columns are
- metric names and values are lists.
- e.g
- +---------+-----------------+-----------------+
- | | loss | step |
- |---------+-----------------+-----------------|
- | trial_0 | [5.1, 3.2, 1.0] | [100, 200, 300] |
- | trial_1 | [5.1, 3.2, 1.0] | [100, 200, 300] |
- +---------+-----------------+-----------------+
- """
+ Args:
+ logfile: str
+
+ Returns:
+ df: DataFrame where indices are trials, columns are
+ metric names and values are lists.
+ e.g
+ +---------+-----------------+-----------------+
+ | | loss | step |
+ |---------+-----------------+-----------------|
+ | trial_0 | [5.1, 3.2, 1.0] | [100, 200, 300] |
+ | trial_1 | [5.1, 3.2, 1.0] | [100, 200, 300] |
+ +---------+-----------------+-----------------+
+ """
trials_dict = get_trials_dict(logfile)
df = pd.DataFrame(trials_dict).transpose()
return df
@@ -141,13 +140,13 @@ def get_trials_df(logfile):
## Get scoring code
def get_experiment_df(experiment_dir):
"""Gets a df of per trial results from an experiment dir.
- The output df can be provided as input to
- performance_profile.compute_performance_profiles.
+ The output df can be provided as input to
+ performance_profile.compute_performance_profiles.
Args:
- experiment_dir: path to experiment directory containing
- results for workloads. Measurements from experiments
- sharing the same prefix but different timestamps are
- collected together.
+ experiment_dir: path to experiment directory containing
+ results for workloads. Measurements from experiments
+ sharing the same prefix but different timestamps are
+ collected together.
The directory structure is assumed to be:
+ experiment_dir
+ study
@@ -156,9 +155,9 @@ def get_experiment_df(experiment_dir):
- eval_measurements.csv
Returns:
- df: DataFrame where indices are trials, columns are
+ df: DataFrame where indices are trials, columns are
metric names and values are lists of length num evals.
- e.g
+ e.g
+----+-----------+--------+----------------------------+--------------------+--------------------+
| | workload | study |trial | validation/accuracy| score |
|----+-----------+--------+----------------------------+--------------------+--------------------|
@@ -167,39 +166,41 @@ def get_experiment_df(experiment_dir):
"""
df = pd.DataFrame()
paths = filter(
- lambda x: re.match(experiment_dir + TIMESTAMP, x) or x == experiment_dir,
- glob.glob(f"{experiment_dir}*"))
+ lambda x: re.match(experiment_dir + TIMESTAMP, x) or x == experiment_dir,
+ glob.glob(f'{experiment_dir}*'),
+ )
for experiment_dir in paths:
study_dirs = os.listdir(experiment_dir)
for study_dir in study_dirs:
workload_dirs = os.listdir(os.path.join(experiment_dir, study_dir))
workload_dirs = [
- w for w in workload_dirs
- if os.path.isdir(os.path.join(experiment_dir, study_dir, w))
+ w
+ for w in workload_dirs
+ if os.path.isdir(os.path.join(experiment_dir, study_dir, w))
]
print(workload_dirs)
for workload in workload_dirs:
data = {
- 'workload': workload,
+ 'workload': workload,
}
logging.info(os.path.join(experiment_dir, study_dir, workload))
trial_dirs = [
- t for t in os.listdir(
- os.path.join(experiment_dir, study_dir, workload))
- if re.match(TRIAL_DIR_REGEX, t)
+ t
+ for t in os.listdir(os.path.join(experiment_dir, study_dir, workload))
+ if re.match(TRIAL_DIR_REGEX, t)
]
for trial in trial_dirs:
eval_measurements_filepath = os.path.join(
- experiment_dir,
- study_dir,
- workload,
- trial,
- MEASUREMENTS_FILENAME,
+ experiment_dir,
+ study_dir,
+ workload,
+ trial,
+ MEASUREMENTS_FILENAME,
)
try:
trial_df = pd.read_csv(eval_measurements_filepath)
- except FileNotFoundError as e:
+ except FileNotFoundError:
logging.info(f'Could not read {eval_measurements_filepath}')
continue
data['trial'] = (trial, experiment_dir)
@@ -221,14 +222,16 @@ def get_workload_metrics_and_targets(workload, split='validation'):
# Extend path according to framework.
workload_metadata['workload_path'] = os.path.join(
- BASE_WORKLOADS_DIR,
- workload_metadata['workload_path'] + f'{framework}',
- 'workload.py')
+ BASE_WORKLOADS_DIR,
+ workload_metadata['workload_path'] + f'{framework}',
+ 'workload.py',
+ )
workload_init_kwargs = {}
workload_obj = workloads_registry.import_workload(
- workload_path=workload_metadata['workload_path'],
- workload_class_name=workload_metadata['workload_class_name'],
- workload_init_kwargs=workload_init_kwargs)
+ workload_path=workload_metadata['workload_path'],
+ workload_class_name=workload_metadata['workload_class_name'],
+ workload_init_kwargs=workload_init_kwargs,
+ )
metric_name = workload_obj.target_metric_name
if split == 'validation':
metric = f'validation/{metric_name}'
diff --git a/scoring/test_performance_profile.py b/scoring/test_performance_profile.py
deleted file mode 100644
index 166c82d09..000000000
--- a/scoring/test_performance_profile.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import os
-
-from absl.testing import absltest
-
-from scoring import performance_profile
-from scoring import scoring_utils
-
-
-class Test(absltest.TestCase):
-
- def test_get_workloads_time_to_target(self):
- # TODO(kasimbeg)
- pass
-
- def test_get_best_trial_index(self):
- # TODO(kasimbeg)
- pass
-
- def test_compute_performance_profiles(self):
- # TODO(kasimbeg)
- pass
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/scoring/test_scoring_utils.py b/scoring/test_scoring_utils.py
index 7509e3e46..64e141976 100644
--- a/scoring/test_scoring_utils.py
+++ b/scoring/test_scoring_utils.py
@@ -1,8 +1,5 @@
-import os
-
from absl.testing import absltest
-from scoring import performance_profile
from scoring import scoring_utils
TEST_LOGFILE = 'test_data/adamw_fastmri_jax_04-18-2023-13-10-58.log'
@@ -11,7 +8,6 @@
class Test(absltest.TestCase):
-
def test_get_trials_dict(self):
trials_dict = scoring_utils.get_trials_dict(TEST_LOGFILE)
self.assertEqual(len(trials_dict['1']['global_step']), NUM_EVALS)
diff --git a/scoring/utils/package_logs.py b/scoring/utils/package_logs.py
index 074075abf..e341570a1 100644
--- a/scoring/utils/package_logs.py
+++ b/scoring/utils/package_logs.py
@@ -3,11 +3,11 @@
python3 package_logs.py --experiment_dir --destination_dir
"""
+
import os
import shutil
-from absl import app
-from absl import flags
+from absl import app, flags
flags.DEFINE_string('experiment_dir', None, 'Path to experiment.')
flags.DEFINE_string('destination_dir', None, 'Path to save submission logs')
@@ -17,10 +17,10 @@
def move_logs(experiment_dir, destination_dir):
"""Copy files from experiment path to destination directory.
- Args:
- experiment_dir: Path to experiment dir.
- destination_dir: Path to destination dir.
- """
+ Args:
+ experiment_dir: Path to experiment dir.
+ destination_dir: Path to destination dir.
+ """
if not os.path.exists(experiment_dir):
raise IOError(f'Directory does not exist {destination_dir}')
diff --git a/scoring/utils/run_workloads.py b/scoring/utils/run_workloads.py
index 683fb3c63..e2de01130 100644
--- a/scoring/utils/run_workloads.py
+++ b/scoring/utils/run_workloads.py
@@ -16,101 +16,111 @@
import subprocess
import time
-from absl import app
-from absl import flags
-from absl import logging
+from absl import app, flags, logging
+import docker
from algoperf import random_utils as prng
from algoperf.workloads.workloads import get_base_workload_name
-import docker
flags.DEFINE_string(
- 'docker_image_url',
- 'europe-west4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo/algoperf_jax_dev',
- 'URL to docker image')
+ 'docker_image_url',
+ 'europe-west4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo/algoperf_jax_dev',
+ 'URL to docker image',
+)
flags.DEFINE_integer(
- 'run_percentage',
- 100,
- 'Percentage of max num steps to run for.'
- 'Must set the flag enable_step_budget to True for this to take effect.')
-flags.DEFINE_string('experiment_name',
- 'my_experiment',
- 'Name of top sub directory in experiment dir.')
-flags.DEFINE_boolean('rsync_data',
- True,
- 'Whether or not to transfer the data from GCP w rsync.')
+ 'run_percentage',
+ 100,
+ 'Percentage of max num steps to run for. '
+ 'Must set the flag enable_step_budget to True for this to take effect.',
+)
+flags.DEFINE_string(
+ 'experiment_name',
+ 'my_experiment',
+ 'Name of top sub directory in experiment dir.',
+)
+flags.DEFINE_boolean(
+ 'rsync_data', True, 'Whether or not to transfer the data from GCP with rsync.'
+)
flags.DEFINE_boolean('local', False, 'Mount local algorithmic-efficiency repo.')
flags.DEFINE_string(
- 'submission_path',
- 'prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py',
- 'Path to reference submission.')
+ 'submission_path',
+ 'prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py',
+ 'Path to reference submission.',
+)
flags.DEFINE_string(
- 'tuning_search_space',
- 'prize_qualification_baselines/external_tuning/tuning_search_space.json',
- 'Path to tuning search space.')
+ 'tuning_search_space',
+ 'prize_qualification_baselines/external_tuning/tuning_search_space.json',
+ 'Path to tuning search space.',
+)
flags.DEFINE_string('framework', 'jax', 'Can be either PyTorch or JAX.')
flags.DEFINE_boolean(
- 'dry_run',
- False,
- 'Whether or not to actually run the docker containers. '
- 'If False, simply print the docker run commands. ')
+ 'dry_run',
+ False,
+ 'Whether or not to actually run the docker containers. '
+ 'If False, simply print the docker run commands. ',
+)
flags.DEFINE_enum(
- 'tuning_ruleset',
- 'external',
- enum_values=['external', 'self'],
- help='Can be either external of self.')
+ 'tuning_ruleset',
+ 'external',
+ enum_values=['external', 'self'],
+ help='Can be either external or self.',
+)
flags.DEFINE_integer('num_studies', 5, 'Number of studies to run')
flags.DEFINE_integer('study_start_index', None, 'Start index for studies.')
flags.DEFINE_integer('study_end_index', None, 'End index for studies.')
flags.DEFINE_integer('num_tuning_trials', 5, 'Number of tuning trials.')
-flags.DEFINE_integer('hparam_start_index',
- None,
- 'Start index for tuning trials.')
+flags.DEFINE_integer(
+ 'hparam_start_index', None, 'Start index for tuning trials.'
+)
flags.DEFINE_integer('hparam_end_index', None, 'End index for tuning trials.')
flags.DEFINE_integer('seed', None, 'Random seed for evaluating a submission.')
-flags.DEFINE_integer('submission_id',
- 0,
- 'Submission ID to generate study and hparam seeds.')
-flags.DEFINE_string('held_out_workloads_config_path',
- None,
- 'Path to config containing held-out workloads')
+flags.DEFINE_integer(
+ 'submission_id', 0, 'Submission ID to generate study and hparam seeds.'
+)
flags.DEFINE_string(
- 'workload_metadata_path',
- None,
- 'Path to config containing dataset and maximum number of steps per workload.'
- 'The default values of these are set to the full budgets as determined '
- 'via the target-setting procedure. '
- 'We provide workload_metadata_external_tuning.json and '
- 'workload_metadata_self_tuning.json as references.'
- 'Note that training will be interrupted at either the set maximum number '
- 'of steps or the fixed workload maximum run time, whichever comes first. '
- 'If your algorithm has a smaller per step time than our baselines '
- 'you may want to increase the number of steps per workload.')
+ 'held_out_workloads_config_path',
+ None,
+ 'Path to config containing held-out workloads',
+)
flags.DEFINE_string(
- 'workloads',
- None,
- 'String representing a comma separated list of workload names.'
- 'If not None, only run this workload, else run all workloads in workload_metadata_path.'
+ 'workload_metadata_path',
+ None,
+ 'Path to config containing dataset and maximum number of steps per workload. '
+ 'The default values of these are set to the full budgets as determined '
+ 'via the target-setting procedure. '
+ 'We provide workload_metadata_external_tuning.json and '
+ 'workload_metadata_self_tuning.json as references. '
+ 'Note that training will be interrupted at either the set maximum number '
+ 'of steps or the fixed workload maximum run time, whichever comes first. '
+ 'If your algorithm has a smaller per-step time than our baselines, '
+ 'you may want to increase the number of steps per workload.',
+)
+flags.DEFINE_string(
+ 'workloads',
+ None,
+ 'String representing a comma separated list of workload names. '
+ 'If not None, only run this workload, else run all workloads in workload_metadata_path.',
+)
+flags.DEFINE_string(
+ 'additional_requirements_path', None, 'Path to requirements.txt if any.'
)
-flags.DEFINE_string('additional_requirements_path',
- None,
- 'Path to requirements.txt if any.')
flags.DEFINE_integer(
- 'max_steps',
- None,
- 'Maximum number of steps to run. Must set flag enable_step_budget.'
- 'This flag takes precedence over the run_percentage flag.')
+ 'max_steps',
+ None,
+ 'Maximum number of steps to run. Must set flag enable_step_budget. '
+ 'This flag takes precedence over the run_percentage flag.',
+)
flags.DEFINE_bool(
- 'enable_step_budget',
- False,
- 'Flag that has to be explicitly set to override time budgets to step budget percentage.'
+ 'enable_step_budget',
+ False,
+ 'Flag that must be explicitly set to override the time budget with the step budget percentage.',
)
FLAGS = flags.FLAGS
def read_held_out_workloads(filename):
- with open(filename, "r") as f:
+ with open(filename, 'r') as f:
held_out_workloads = json.load(f)
return held_out_workloads
@@ -132,11 +142,13 @@ def kill_containers():
def gpu_is_active():
- output = subprocess.check_output([
+ output = subprocess.check_output(
+ [
'nvidia-smi',
'--query-gpu=utilization.gpu',
- '--format=csv,noheader,nounits'
- ])
+ '--format=csv,noheader,nounits',
+ ]
+ )
return any(int(x) > 0 for x in output.decode().splitlines())
@@ -151,7 +163,8 @@ def wait_until_container_not_running(sleep_interval=5 * 60):
gpu_last_active = datetime.datetime.now().timestamp()
if (datetime.datetime.now().timestamp() - gpu_last_active) > 45 * 60:
kill_containers(
- "Killing container: GPUs have been inactive > 45 minutes...")
+ 'Killing container: GPUs have been inactive > 45 minutes...'
+ )
time.sleep(sleep_interval)
return
@@ -167,7 +180,9 @@ def main(_):
hparam_start_index_flag = ''
hparam_end_index_flag = ''
if FLAGS.hparam_start_index:
- hparam_start_index_flag = f'--hparam_start_index {FLAGS.hparam_start_index} '
+ hparam_start_index_flag = (
+ f'--hparam_start_index {FLAGS.hparam_start_index} '
+ )
if FLAGS.hparam_end_index:
hparam_end_index_flag = f'--hparam_end_index {FLAGS.hparam_end_index} '
study_start_index = FLAGS.study_start_index if FLAGS.study_start_index else 0
@@ -178,7 +193,9 @@ def main(_):
additional_requirements_path_flag = ''
if FLAGS.additional_requirements_path:
- additional_requirements_path_flag = f'--additional_requirements_path {FLAGS.additional_requirements_path} '
+ additional_requirements_path_flag = (
+ f'--additional_requirements_path {FLAGS.additional_requirements_path} '
+ )
submission_id = FLAGS.submission_id
@@ -188,7 +205,7 @@ def main(_):
rng_seed = struct.unpack('I', os.urandom(4))[0]
logging.info('Using RNG seed %d', rng_seed)
- rng_key = (prng.fold_in(prng.PRNGKey(rng_seed), hash(submission_id)))
+ rng_key = prng.fold_in(prng.PRNGKey(rng_seed), hash(submission_id))
with open(FLAGS.workload_metadata_path) as f:
workload_metadata = json.load(f)
@@ -199,20 +216,24 @@ def main(_):
# Read heldout workloads
if FLAGS.held_out_workloads_config_path:
held_out_workloads = read_held_out_workloads(
- FLAGS.held_out_workloads_config_path)
+ FLAGS.held_out_workloads_config_path
+ )
workloads = workloads + held_out_workloads
# Filter workloads if explicit workloads specified
if FLAGS.workloads is not None:
workloads = list(
- filter(lambda x: x in FLAGS.workloads.split(','), workloads))
+ filter(lambda x: x in FLAGS.workloads.split(','), workloads)
+ )
if len(workloads) != len(FLAGS.workloads.split(',')):
unmatched_workloads = set(FLAGS.workloads.split(',')) - set(workloads)
raise ValueError(f'Invalid workload name {unmatched_workloads}')
rng_subkeys = prng.split(rng_key, num_studies)
- for study_index, rng_subkey in zip(range(study_start_index, study_end_index + 1), rng_subkeys):
+ for study_index, rng_subkey in zip(
+ range(study_start_index, study_end_index + 1), rng_subkeys
+ ):
print('-' * 100)
print('*' * 40, f'Starting study {study_index + 1}/{num_studies}', '*' * 40)
print('-' * 100)
@@ -225,40 +246,46 @@ def main(_):
base_workload_name = get_base_workload_name(workload)
wait_until_container_not_running()
os.system(
- "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches
+ "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'"
+ ) # clear caches
print('=' * 100)
dataset = workload_metadata[base_workload_name]['dataset']
max_steps_flag = ''
if FLAGS.enable_step_budget:
- run_fraction = FLAGS.run_percentage / 100.
+ run_fraction = FLAGS.run_percentage / 100.0
if FLAGS.max_steps is None:
- max_steps = int(workload_metadata[base_workload_name]['max_steps'] *
- run_fraction)
+ max_steps = int(
+ workload_metadata[base_workload_name]['max_steps'] * run_fraction
+ )
else:
max_steps = FLAGS.max_steps
max_steps_flag = f'-m {max_steps}'
mount_repo_flag = ''
if FLAGS.local:
- mount_repo_flag = '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency '
- command = ('docker run -t -d -v /home/kasimbeg/data/:/data/ '
- '-v /home/kasimbeg/experiment_runs/:/experiment_runs '
- '-v /home/kasimbeg/experiment_runs/logs:/logs '
- f'{mount_repo_flag}'
- '--gpus all --ipc=host '
- f'{docker_image_url} '
- f'-d {dataset} '
- f'-f {framework} '
- f'-s {submission_path} '
- f'-w {workload} '
- f'-e {study_dir} '
- f'{max_steps_flag} '
- f'--num_tuning_trials {num_tuning_trials} '
- f'--rng_seed {run_seed} '
- f'{additional_requirements_path_flag}'
- '-c false '
- '-o true '
- '-i true ')
+ mount_repo_flag = (
+ '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency '
+ )
+ command = (
+ 'docker run -t -d -v /home/kasimbeg/data/:/data/ '
+ '-v /home/kasimbeg/experiment_runs/:/experiment_runs '
+ '-v /home/kasimbeg/experiment_runs/logs:/logs '
+ f'{mount_repo_flag}'
+ '--gpus all --ipc=host '
+ f'{docker_image_url} '
+ f'-d {dataset} '
+ f'-f {framework} '
+ f'-s {submission_path} '
+ f'-w {workload} '
+ f'-e {study_dir} '
+ f'{max_steps_flag} '
+ f'--num_tuning_trials {num_tuning_trials} '
+ f'--rng_seed {run_seed} '
+ f'{additional_requirements_path_flag}'
+ '-c false '
+ '-o true '
+ '-i true '
+ )
# Append tuning ruleset flags
tuning_ruleset_flags = ''
@@ -280,18 +307,19 @@ def main(_):
return_code = 0
if return_code == 0:
print(
- f'SUCCESS: container for {framework} {workload} launched successfully'
+ f'SUCCESS: container for {framework} {workload} launched successfully'
)
print(f'Command: {command}')
print(f'Results will be logged to {experiment_name}')
else:
print(
- f'Failed: container for {framework} {workload} failed with exit code {return_code}.'
+ f'Failed: container for {framework} {workload} failed with exit code {return_code}.'
)
print(f'Command: {command}')
wait_until_container_not_running()
os.system(
- "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches
+ "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'"
+ ) # clear caches
print('=' * 100)
diff --git a/scoring/utils/slurm/README.md b/scoring/utils/slurm/README.md
index ffd56fbf3..a8e41f04b 100644
--- a/scoring/utils/slurm/README.md
+++ b/scoring/utils/slurm/README.md
@@ -1,26 +1,63 @@
+# Launching SLURM jobs with SBATCH
+
This folder contains a SLURM batch script that can be used to run jobs where each job corresponds to a training run on a given workload, training algorithm, random seed and tuning trial (if on external tuning ruleset).
To launch jobs:
-1) Generate a job config. The following command will generate a config.json.
-```
+
+1. Generate a job config. The following command will generate a config.json.
+
+```bash
python3 make_job_config.py \
--submission_path \
--tuning_search_space \
--experiment_dir $HOME/experiments/ \
--framework
```
-2) Save the config.json in the same directory you will run the sbatch script from.
-3) Check the sbatch script `run_jobs.sh`.
+
+2. Save the config.json in the same directory from which you will run the sbatch script.
+3. Copy the example sbatch script `run_jobs.sh` and adjust it as follows:
+
+- Set the task array range to match the number of tasks in the config (e.g., `0-119` for 120 tasks).
+
```
#SBATCH --array=0-119
```
+
+- Set the output and error log directories for the SLURM logs.
+
```
#SBATCH --output=experiments///job_%A_%a.out
#SBATCH --error=experiments///job_%A_%a.err
```
-4) Submit a SLURM batch job by running:
+
+- Update the GCP project information, Docker image, config file path, and logs bucket as necessary:
+
+```bash
+REPO="us-central1-docker.pkg.dev"
+IMAGE="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_main"
+yes | gcloud auth configure-docker $REPO
+docker pull $IMAGE
+# Job config (ATTENTION: you may want to modify this)
+config_file="$HOME/configs/pmap_job_config.json" # Replace with your config file path
+LOGS_BUCKET="algoperf-runs-internal"
+```
+
+4. Submit a SLURM batch job by running:
+
```
sbatch run_jobs.sh
```
+
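+Once submitted, each entry in the task array runs as its own SLURM job. As a quick sanity check you can monitor the array with standard SLURM utilities (a sketch, not tooling from this repository; replace `<job_id>` with the ID printed by `sbatch`):
+
+```bash
+# Show your queued and running array tasks.
+squeue -u $USER
+# Summarize per-task state and runtime during or after the run.
+sacct -j <job_id> --format=JobID,JobName,State,Elapsed
+```
+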
+# Set up new SLURM cluster
+
+If you are setting up a new cluster, we recommend using the [HPC toolkit to set up a SLURM cluster](https://cloud.google.com/cluster-toolkit/docs/quickstarts/slurm-cluster).
+To set up the new cluster:
+
+1. [Install the Google Cluster Toolkit](https://github.com/GoogleCloudPlatform/cluster-toolkit?tab=readme-ov-file#quickstart).
+2. Create and deploy a packer node to create a base image for the cluster nodes. See [packer builder terraform blueprint](/scoring/utils/slurm/algoperf_slurm_packer_builder.yaml).
+3. Manually update the image:
+ 1. Create a VM from the Disk image created in the previous step.
+ 2. Install the NVIDIA container toolkit on the VM.
+ 3. Transfer the data from GCP bucket to `/opt/data`.
+ 4. Create a new disk image from the VM.
+4. Create and deploy the cluster. See [cluster terraform blueprint](/scoring/utils/slurm/algoperf_slurm_cluster.yaml). A command-line sketch of the blueprint deployment steps is shown below.
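+
+A rough sketch of the deployment commands for steps 2 and 4, assuming the Cluster Toolkit's `gcluster create` / `gcluster deploy` quickstart workflow (the binary is called `ghpc` in older toolkit releases) and the `deployment_name` values defined in the blueprints; replace `/path/to/algorithmic-efficiency` with your checkout path:
+
+```bash
+# Build the Packer image deployment (step 2), then perform the manual image update (step 3).
+./gcluster create /path/to/algorithmic-efficiency/scoring/utils/slurm/algoperf_slurm_packer_builder.yaml
+./gcluster deploy algoperf-slurm-packer
+# Create and deploy the cluster itself (step 4).
+./gcluster create /path/to/algorithmic-efficiency/scoring/utils/slurm/algoperf_slurm_cluster.yaml
+./gcluster deploy algoperf-slurm-internal
+```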
diff --git a/scoring/utils/slurm/algoperf_slurm_cluster.yaml b/scoring/utils/slurm/algoperf_slurm_cluster.yaml
new file mode 100644
index 000000000..073fe98cc
--- /dev/null
+++ b/scoring/utils/slurm/algoperf_slurm_cluster.yaml
@@ -0,0 +1,105 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+blueprint_name: algoperf-slurm-internal
+
+vars:
+ project_id: training-algorithms-external
+ deployment_name: algoperf-slurm-internal
+ region: europe-west4
+ zone: europe-west4-a
+ disk_size_gb: 3000
+ slurm_cluster_name: algoperf
+ image_name: algoperf-image-data-container-tkt
+
+# Recommended to use GCS backend for Terraform state
+# See https://github.com/GoogleCloudPlatform/hpc-toolkit/tree/main/examples#optional-setting-up-a-remote-terraform-state
+#
+# terraform_backend_defaults:
+# type: gcs
+# configuration:
+# bucket: <>
+
+deployment_groups:
+ - group: primary
+ modules:
+ - id: network
+ source: modules/network/vpc
+
+ - id: homefs
+ source: community/modules/file-system/nfs-server
+ use: [network]
+ settings:
+ local_mounts: [/home]
+ disk_size: 3000
+ zone: $(vars.zone)
+
+ - id: script
+ source: modules/scripts/startup-script
+ settings:
+
+ - group: cluster
+ modules:
+ - id: v100_nodeset
+ source: community/modules/compute/schedmd-slurm-gcp-v6-nodeset
+ use:
+ - network
+ settings:
+ node_count_dynamic_max: 25 # set to 0 if you want node to live forever
+ region: $(vars.region)
+ zone: $(vars.zone)
+ enable_placement: false
+ bandwidth_tier: gvnic_enabled
+ machine_type: n1-standard-64
+ guest_accelerator:
+ - type: nvidia-tesla-v100
+ count: 8
+ instance_image:
+ project: $(vars.project_id)
+ name: $(vars.image_name)
+ instance_image_custom: true
+
+ - id: v100_partition
+ source: community/modules/compute/schedmd-slurm-gcp-v6-partition
+ use: [v100_nodeset]
+ settings:
+ exclusive: false
+ partition_name: v100
+ is_default: true
+
+ - id: slurm_login
+ source: community/modules/scheduler/schedmd-slurm-gcp-v6-login
+ use: [network]
+ settings:
+ enable_login_public_ips: true
+ instance_image:
+ project: $(vars.project_id)
+ name: $(vars.image_name)
+ instance_image_custom: true
+ zone: $(vars.zone)
+
+ - id: slurm_controller
+ source: community/modules/scheduler/schedmd-slurm-gcp-v6-controller
+ use:
+ - network
+ - v100_partition
+ - homefs
+ - slurm_login
+ settings:
+ enable_controller_public_ips: true
+ instance_image:
+ project: $(vars.project_id)
+ name: $(vars.image_name)
+ instance_image_custom: true
+ region: $(vars.region)
diff --git a/scoring/utils/slurm/algoperf_slurm_packer_builder.yaml b/scoring/utils/slurm/algoperf_slurm_packer_builder.yaml
new file mode 100644
index 000000000..f3b5be5dd
--- /dev/null
+++ b/scoring/utils/slurm/algoperf_slurm_packer_builder.yaml
@@ -0,0 +1,131 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+---
+blueprint_name: algoperf-slurm-packer
+
+vars:
+ project_id: training-algorithms-external
+ deployment_name: algoperf-slurm-packer
+ region: europe-west4
+ zone: europe-west4-a
+ new_image:
+ family: algoperf-image
+ project: $(vars.project_id)
+ disk_size_gb: 3000
+ slurm_cluster_name: algoperf-packer
+
+# Recommended to use GCS backend for Terraform state
+# See https://github.com/GoogleCloudPlatform/hpc-toolkit/tree/main/examples#optional-setting-up-a-remote-terraform-state
+#
+# terraform_backend_defaults:
+# type: gcs
+# configuration:
+# bucket: <>
+
+deployment_groups:
+ - group: primary
+ modules:
+ - id: network
+ source: modules/network/vpc
+
+ - id: script
+ source: modules/scripts/startup-script
+ settings:
+ region: $(vars.region)
+ install_ansible: true
+ docker:
+ enabled: true
+ world_writable: true
+ # (TODO) Do I need this?
+ configure_ssh_host_patterns:
+ - 10.0.0.*
+ - 10.1.0.*
+ - 10.2.0.*
+ - 10.3.0.*
+ - 10.4.0.*
+ - 10.5.0.*
+ - 10.6.0.*
+ - 10.7.0.*
+ - $(vars.slurm_cluster_name)*
+ runners:
+ - type: shell
+ destination: install-ml-libraries.sh
+ content: |
+ #!/bin/bash
+ # this script is designed to execute on Slurm images published by SchedMD that:
+ # - are based on Debian distribution of Linux
+ # - have NVIDIA drivers pre-installed
+
+ set -e -o pipefail
+
+ echo "deb https://packages.cloud.google.com/apt google-fast-socket main" > /etc/apt/sources.list.d/google-fast-socket.list
+ apt-get update --allow-releaseinfo-change
+ apt-get install --assume-yes google-fast-socket
+
+ CONDA_BASE=/opt/conda
+
+ if [ -d $CONDA_BASE ]; then
+ exit 0
+ fi
+
+ DL_DIR=\$(mktemp -d)
+ cd $DL_DIR
+ curl -L -O https://github.com/conda-forge/miniforge/releases/download/24.7.1-2/Miniforge3-24.7.1-2-Linux-x86_64.sh
+ HOME=$DL_DIR bash Miniforge3-24.7.1-2-Linux-x86_64.sh -b -p $CONDA_BASE
+ cd -
+ rm -rf $DL_DIR
+ unset DL_DIR
+
+ source $CONDA_BASE/bin/activate base
+ conda init --system
+ conda config --system --set auto_activate_base False
+ # following channel ordering is important! use strict_priority!
+ conda config --system --set channel_priority strict
+ conda update -n base conda --yes
+
+ ### create a virtual environment for tensorflow
+ conda create -n tf python=3.11 --yes
+ conda activate tf
+ pip install tensorflow[and-cuda]==2.18.*
+ pip install tensorrt==10.6.*
+
+ ### create a virtual environment for pytorch
+ conda create -n pytorch python=3.11 --yes
+ conda activate pytorch
+ pip install torch torchvision torchaudio
+
+ - group: packer
+ modules:
+ - id: custom-image
+ source: modules/packer/custom-image
+ kind: packer
+ use:
+ - network
+ - script
+ settings:
+ # give VM a public IP to ensure startup script can reach public internet
+ # w/o new VPC
+ omit_external_ip: false
+ source_image_project_id: [schedmd-slurm-public]
+ # see latest in https://github.com/GoogleCloudPlatform/slurm-gcp/blob/master/docs/images.md#published-image-family
+ source_image_family: slurm-gcp-6-8-debian-11
+ # You can find size of source image by using following command
+ # gcloud compute images describe-from-family --project schedmd-slurm-public
+ disk_size: $(vars.disk_size_gb)
+ image_family: $(vars.new_image.family)
+ # building this image does not require a GPU-enabled VM
+ machine_type: c2-standard-16
+ state_timeout: 300m
+ zone: $(vars.zone)
diff --git a/scoring/utils/slurm/config.json b/scoring/utils/slurm/config.json
new file mode 100644
index 000000000..cb49f9bf4
--- /dev/null
+++ b/scoring/utils/slurm/config.json
@@ -0,0 +1,106 @@
+{
+ "0": {
+ "framework": "jax",
+ "workload": "imagenet_resnet",
+ "dataset": "imagenet",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": 411096763,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "1": {
+ "framework": "jax",
+ "workload": "imagenet_vit",
+ "dataset": "imagenet",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": -1884713130,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "2": {
+ "framework": "jax",
+ "workload": "fastmri",
+ "dataset": "fastmri",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": -214785144,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "3": {
+ "framework": "jax",
+ "workload": "ogbg",
+ "dataset": "ogbg",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": -893097833,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "4": {
+ "framework": "jax",
+ "workload": "wmt",
+ "dataset": "wmt",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": -1244182279,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "5": {
+ "framework": "jax",
+ "workload": "librispeech_deepspeech",
+ "dataset": "librispeech",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": 1546003634,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "6": {
+ "framework": "jax",
+ "workload": "criteo1tb",
+ "dataset": "criteo1tb",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": -2062333143,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "7": {
+ "framework": "jax",
+ "workload": "librispeech_conformer",
+ "dataset": "librispeech",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": 409209730,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ }
+}
diff --git a/scoring/utils/slurm/make_job_config.py b/scoring/utils/slurm/make_job_config.py
index 116e70459..f6a1ca158 100644
--- a/scoring/utils/slurm/make_job_config.py
+++ b/scoring/utils/slurm/make_job_config.py
@@ -6,60 +6,66 @@
--experiment_dir $HOME/experiments/ \
--framework
"""
+
import json
import os
-from absl import app
-from absl import flags
import jax
+from absl import app, flags
SUBMISSION_PATH = 'submissions_algorithms/submissions/self_tuning/schedule_free_adamw_v2/submission.py'
-EXPERIMENT_DIR = 'submissions/rolling_leaderboard/self_tuning/schedule_free_adamw_v2'
+EXPERIMENT_DIR = (
+ 'submissions/rolling_leaderboard/self_tuning/schedule_free_adamw_v2'
+)
TUNING_SEARCH_SPACE = None
FRAMEWORK = 'pytorch'
TUNING_RULESET = 'self'
flags.DEFINE_string(
- 'submission_path',
- SUBMISSION_PATH,
- 'Path to submission module relative to algorithmic-efficiency dir.')
+ 'submission_path',
+ SUBMISSION_PATH,
+ 'Path to submission module relative to algorithmic-efficiency dir.',
+)
+flags.DEFINE_string(
+ 'tuning_search_space',
+ TUNING_SEARCH_SPACE,
+ 'Path to tuning search space for submission module relative to algorithmic-efficiency dir.',
+)
flags.DEFINE_string(
- 'tuning_search_space',
- TUNING_SEARCH_SPACE,
- 'Path to tuning search space for submission module relative to algorithmic-efficiency dir.'
+ 'experiment_dir',
+ EXPERIMENT_DIR,
+ 'Path to experiment dir where logs will be saved.',
)
-flags.DEFINE_string('experiment_dir',
- EXPERIMENT_DIR,
- 'Path to experiment dir where logs will be saved.')
flags.DEFINE_enum(
- 'framework',
- FRAMEWORK,
- enum_values=['jax', 'pytorch'],
- help='Can be either pytorch or jax.')
+ 'framework',
+ FRAMEWORK,
+ enum_values=['jax', 'pytorch'],
+ help='Can be either pytorch or jax.',
+)
flags.DEFINE_integer('seed', 0, 'RNG seed to to generate study seeds from.')
flags.DEFINE_enum(
- 'tuning_ruleset',
- TUNING_RULESET,
- enum_values=['external', 'self'],
- help='Which tuning ruleset to score this submission on. Can be external or self.'
+ 'tuning_ruleset',
+ TUNING_RULESET,
+ enum_values=['external', 'self'],
+ help='Which tuning ruleset to score this submission on. Can be external or self.',
)
FLAGS = flags.FLAGS
-MIN_INT = -2**(31)
-MAX_INT = 2**(31) - 1
+MIN_INT = -(2 ** (31))
+MAX_INT = 2 ** (31) - 1
NUM_TUNING_TRIALS = 5 # For external tuning ruleset
NUM_STUDIES = 3
WORKLOADS = {
- "imagenet_resnet": {"dataset": "imagenet"},
- "imagenet_vit": {"dataset": "imagenet"},
- "fastmri": {"dataset": "fastmri"},
- "ogbg": {"dataset": "ogbg"},
- "wmt": {"dataset": "wmt"},
- "librispeech_deepspeech": {"dataset": "librispeech"},
- "criteo1tb": {"dataset": "criteo1tb"},
- "librispeech_conformer": {"dataset": "librispeech"}
+ 'imagenet_resnet': {'dataset': 'imagenet'},
+ 'imagenet_vit': {'dataset': 'imagenet'},
+ 'fastmri': {'dataset': 'fastmri'},
+ 'ogbg': {'dataset': 'ogbg'},
+ 'wmt': {'dataset': 'wmt'},
+ 'librispeech_deepspeech': {'dataset': 'librispeech'},
+ 'criteo1tb': {'dataset': 'criteo1tb'},
+ 'librispeech_conformer': {'dataset': 'librispeech'},
}
@@ -81,7 +87,7 @@ def main(_):
print(seed)
# Add job
job = {}
- study_dir = os.path.join(FLAGS.experiment_dir, f"study_{study_index}")
+ study_dir = os.path.join(FLAGS.experiment_dir, f'study_{study_index}')
job['framework'] = FLAGS.framework
job['workload'] = workload
job['dataset'] = WORKLOADS[workload]['dataset']
@@ -103,7 +109,7 @@ def main(_):
print(seed)
# Add job
job = {}
- study_dir = os.path.join(FLAGS.experiment_dir, f"study_{study_index}")
+ study_dir = os.path.join(FLAGS.experiment_dir, f'study_{study_index}')
job['framework'] = FLAGS.framework
job['workload'] = workload
job['dataset'] = WORKLOADS[workload]['dataset']
@@ -119,7 +125,7 @@ def main(_):
# Convert job array to dict with job indices
job_dict = {}
for i, job in enumerate(jobs):
- job_dict[f"{i}"] = job
+ job_dict[f'{i}'] = job
with open('config.json', 'w') as f:
json.dump(job_dict, f, indent=4)
diff --git a/scoring/utils/slurm/run_jobs.sh b/scoring/utils/slurm/run_jobs.sh
new file mode 100644
index 000000000..5fcf8f69e
--- /dev/null
+++ b/scoring/utils/slurm/run_jobs.sh
@@ -0,0 +1,83 @@
+#!/bin/bash
+
+#SBATCH --nodes=1 # give it a full node
+#SBATCH --ntasks-per-node=1
+#SBATCH --array=
+#SBATCH --partition=v100
+#SBATCH --gpus-per-node=8
+#SBATCH --exclusive # do not allow other jobs to share this node
+#SBATCH --output=experiments/tests/jit_debug_deepspeech_old_stephint_nadamw/job_%A_%a.out
+#SBATCH --error=experiments/tests/jit_debug_deepspeech_old_stephint_nadamw/job_%A_%a.err
+
+# Usage: sbatch run_jobs.sh
+# This script reads config.json and launches a sbatch job using task
+# arrays where each job in the array corresponds to a training run
+# for a workload given a random seed and tuning trial index.
+# To generate the config.json use make_job_config.py.
+
+set -x
+
+# Pull docker image (ATTENTION: you may want to modify this)
+REPO=""
+IMAGE=""
+yes | gcloud auth configure-docker $REPO
+docker pull $IMAGE
+# Job config (ATTENTION: you may want to modify this)
+config_file="" # Replace with your config file path
+LOGS_BUCKET="" # replace with your bucket used for logging
+
+
+# Function to read a JSON file and extract a value by key
+read_json_value() {
+ local json_file="$1"
+ local index="$2"
+ local key="$3"
+ local value=$(jq -r ".[\"$index\"].$key" "$json_file")
+ echo "$value"
+}
+
+# Check if jq is installed
+if ! command -v jq &> /dev/null
+then
+ echo "jq could not be found. Please install it."
+ exit 1
+fi
+
+TASK="$SLURM_ARRAY_TASK_ID"
+FRAMEWORK=$(read_json_value "$config_file" "$TASK" "framework")
+DATASET=$(read_json_value "$config_file" "$TASK" "dataset")
+SUBMISSION_PATH=$(read_json_value "$config_file" "$TASK" "submission_path")
+TUNING_SEARCH_SPACE=$(read_json_value "$config_file" "$TASK" "tuning_search_space")
+EXPERIMENT_DIR=$(read_json_value "$config_file" "$TASK" "experiment_dir")
+MAX_STEPS=$(read_json_value "$config_file" "$TASK" "max_steps")
+RNG_SEED=$(read_json_value "$config_file" "$TASK" "rng_seed")
+WORKLOAD=$(read_json_value "$config_file" "$TASK" "workload")
+HPARAM_START_INDEX=$(read_json_value "$config_file" "$TASK" "hparam_start_index")
+HPARAM_END_INDEX=$(read_json_value "$config_file" "$TASK" "hparam_end_index")
+NUM_TUNING_TRIALS=$(read_json_value "$config_file" "$TASK" "num_tuning_trials")
+TUNING_RULESET=$(read_json_value "$config_file" "$TASK" "tuning_ruleset")
+MAX_GLOBAL_STEPS=$(read_json_value "$config_file" "$MAX_GLOBAL_STEPS" "max_global_steps")
+
+docker run \
+ -v /opt/data/:/data/ \
+ -v $HOME/submissions_algorithms/:/algorithmic-efficiency/submissions_algorithms \
+ --gpus all \
+ --ipc=host \
+ $IMAGE \
+ -d $DATASET \
+ -f $FRAMEWORK \
+ -s $SUBMISSION_PATH \
+ -w $WORKLOAD \
+ -t $TUNING_SEARCH_SPACE \
+ -e $EXPERIMENT_DIR \
+ -c False \
+ -o True \
+ --rng_seed $RNG_SEED \
+ --hparam_start_index $HPARAM_START_INDEX \
+ --hparam_end_index $HPARAM_END_INDEX \
+ --num_tuning_trials $NUM_TUNING_TRIALS \
+ --tuning_ruleset $TUNING_RULESET \
+ --logs_bucket $LOGS_BUCKET \
+ -i true \
+ -r false
\ No newline at end of file
diff --git a/scoring/utils/workload_metadata_external_tuning.json b/scoring/utils/workload_metadata_external_tuning.json
index c205d28b2..c7d4ae195 100644
--- a/scoring/utils/workload_metadata_external_tuning.json
+++ b/scoring/utils/workload_metadata_external_tuning.json
@@ -1,34 +1,34 @@
{
- "imagenet_resnet": {
- "max_steps": 186666,
- "dataset": "imagenet"
- },
- "imagenet_vit": {
- "max_steps": 186666,
- "dataset": "imagenet"
- },
- "fastmri": {
- "max_steps": 36189,
- "dataset": "fastmri"
- },
- "ogbg": {
- "max_steps": 80000,
- "dataset": "ogbg"
- },
- "wmt": {
- "max_steps": 133333,
- "dataset": "wmt"
- },
- "librispeech_deepspeech": {
- "max_steps": 48000,
- "dataset": "librispeech"
- },
- "criteo1tb": {
- "max_steps": 10666,
- "dataset": "criteo1tb"
- },
- "librispeech_conformer": {
- "max_steps": 80000,
- "dataset": "librispeech"
- }
- }
\ No newline at end of file
+ "imagenet_resnet": {
+ "max_steps": 186666,
+ "dataset": "imagenet"
+ },
+ "imagenet_vit": {
+ "max_steps": 186666,
+ "dataset": "imagenet"
+ },
+ "fastmri": {
+ "max_steps": 36189,
+ "dataset": "fastmri"
+ },
+ "ogbg": {
+ "max_steps": 80000,
+ "dataset": "ogbg"
+ },
+ "wmt": {
+ "max_steps": 133333,
+ "dataset": "wmt"
+ },
+ "librispeech_deepspeech": {
+ "max_steps": 48000,
+ "dataset": "librispeech"
+ },
+ "criteo1tb": {
+ "max_steps": 10666,
+ "dataset": "criteo1tb"
+ },
+ "librispeech_conformer": {
+ "max_steps": 80000,
+ "dataset": "librispeech"
+ }
+}
diff --git a/scoring/utils/workload_metadata_self_tuning.json b/scoring/utils/workload_metadata_self_tuning.json
index 105d5c52f..9d3e6b93d 100644
--- a/scoring/utils/workload_metadata_self_tuning.json
+++ b/scoring/utils/workload_metadata_self_tuning.json
@@ -1,34 +1,34 @@
{
- "imagenet_resnet": {
- "max_steps": 559998,
- "dataset": "imagenet"
- },
- "imagenet_vit": {
- "max_steps": 559998,
- "dataset": "imagenet"
- },
- "fastmri": {
- "max_steps": 108567,
- "dataset": "fastmri"
- },
- "ogbg": {
- "max_steps": 240000,
- "dataset": "ogbg"
- },
- "wmt": {
- "max_steps": 399999,
- "dataset": "wmt"
- },
- "librispeech_deepspeech": {
- "max_steps": 144000,
- "dataset": "librispeech"
- },
- "criteo1tb": {
- "max_steps": 31998,
- "dataset": "criteo1tb"
- },
- "librispeech_conformer": {
- "max_steps": 240000,
- "dataset": "librispeech"
- }
- }
\ No newline at end of file
+ "imagenet_resnet": {
+ "max_steps": 559998,
+ "dataset": "imagenet"
+ },
+ "imagenet_vit": {
+ "max_steps": 559998,
+ "dataset": "imagenet"
+ },
+ "fastmri": {
+ "max_steps": 108567,
+ "dataset": "fastmri"
+ },
+ "ogbg": {
+ "max_steps": 240000,
+ "dataset": "ogbg"
+ },
+ "wmt": {
+ "max_steps": 399999,
+ "dataset": "wmt"
+ },
+ "librispeech_deepspeech": {
+ "max_steps": 144000,
+ "dataset": "librispeech"
+ },
+ "criteo1tb": {
+ "max_steps": 31998,
+ "dataset": "criteo1tb"
+ },
+ "librispeech_conformer": {
+ "max_steps": 240000,
+ "dataset": "librispeech"
+ }
+}
diff --git a/submission_runner.py b/submission_runner.py
index 468a04c7c..8dc8589cb 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -17,41 +17,37 @@
import datetime
import gc
import importlib
-from inspect import signature
import itertools
import json
import os
import struct
import time
+from inspect import signature
from types import MappingProxyType
from typing import Any, Dict, Optional, Tuple
-from absl import app
-from absl import flags
-from absl import logging
import jax
import tensorflow as tf
import torch
import torch.distributed as dist
+from absl import app, flags, logging
# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make
# it unavailable to JAX.
tf.config.set_visible_devices([], 'GPU')
-from algoperf import checkpoint_utils
-from algoperf import halton
-from algoperf import logger_utils
-from algoperf import random_utils as prng
-from algoperf import spec
-from algoperf.profiler import PassThroughProfiler
-from algoperf.profiler import Profiler
-from algoperf.pytorch_utils import pytorch_init
-from algoperf.pytorch_utils import pytorch_setup
-from algoperf.pytorch_utils import sync_ddp_time
-from algoperf.workloads import workloads
+from algoperf import checkpoint_utils, halton, logger_utils, spec # noqa: E402
+from algoperf import random_utils as prng # noqa: E402
+from algoperf.profiler import PassThroughProfiler, Profiler # noqa: E402
+from algoperf.pytorch_utils import ( # noqa: E402
+ pytorch_init,
+ pytorch_setup,
+ sync_ddp_time,
+)
+from algoperf.workloads import workloads # noqa: E402
# Environment variables
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings.
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Disables TensorRT, CUDA warnings.
# disable only for deepspeech if it works fine for other workloads
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'
@@ -62,106 +58,121 @@
WORKLOADS = workloads.WORKLOADS
flags.DEFINE_string(
- 'submission_path',
- None,
- 'The relative path of the Python file containing submission functions. '
- 'NOTE: the submission dir must have an __init__.py file!')
+ 'submission_path',
+ None,
+ 'The relative path of the Python file containing submission functions. '
+ 'NOTE: the submission dir must have an __init__.py file!',
+)
flags.DEFINE_string(
- 'workload',
- None,
- help=f'The name of the workload to run.\n Choices: {list(WORKLOADS.keys())}'
+ 'workload',
+ None,
+ help=f'The name of the workload to run.\n Choices: {list(WORKLOADS.keys())}',
)
flags.DEFINE_enum(
- 'tuning_ruleset',
- 'external',
- enum_values=['external', 'self'],
- help='Which tuning ruleset to use.')
+ 'tuning_ruleset',
+ 'external',
+ enum_values=['external', 'self'],
+ help='Which tuning ruleset to use.',
+)
flags.DEFINE_string(
- 'tuning_search_space',
- None,
- 'The path to the JSON file describing the external tuning search space.')
-flags.DEFINE_integer('num_tuning_trials',
- 1,
- 'The number of external hyperparameter trials to run.')
+ 'tuning_search_space',
+ None,
+ 'The path to the JSON file describing the external tuning search space.',
+)
+flags.DEFINE_integer(
+ 'num_tuning_trials', 1, 'The number of external hyperparameter trials to run.'
+)
flags.DEFINE_string('data_dir', '~/data', 'Dataset location.')
-flags.DEFINE_string('imagenet_v2_data_dir',
- None,
- 'Dataset location for ImageNet-v2.')
-flags.DEFINE_string('librispeech_tokenizer_vocab_path',
- '',
- 'Location to librispeech tokenizer.')
+flags.DEFINE_string(
+ 'imagenet_v2_data_dir', None, 'Dataset location for ImageNet-v2.'
+)
+flags.DEFINE_string(
+ 'librispeech_tokenizer_vocab_path', '', 'Location to librispeech tokenizer.'
+)
flags.DEFINE_enum(
- 'framework',
- None,
- enum_values=['jax', 'pytorch'],
- help='Whether to use Jax or Pytorch for the submission. Controls among '
- 'other things if the Jax or Numpy RNG library is used for RNG.')
+ 'framework',
+ None,
+ enum_values=['jax', 'pytorch'],
+  help='Whether to use Jax or Pytorch for the submission. Controls, among '
+  'other things, whether the Jax or Numpy RNG library is used for RNG.',
+)
flags.DEFINE_boolean(
- 'torch_compile',
- True,
- 'Whether to use `torch.compile` to JIT-compile PyTorch code. '
- 'This will only take effect when `framework`==pytorch.')
+ 'torch_compile',
+ True,
+ 'Whether to use `torch.compile` to JIT-compile PyTorch code. '
+ 'This will only take effect when `framework`==pytorch.',
+)
flags.DEFINE_string(
- 'experiment_dir',
- None,
- 'The root directory to store all experiments. '
- 'It is required and the directory should have '
- 'an absolute path rather than a relative path.')
+ 'experiment_dir',
+ None,
+ 'The root directory to store all experiments. '
+ 'It is required and the directory should have '
+ 'an absolute path rather than a relative path.',
+)
flags.DEFINE_string('experiment_name', None, 'Name of the experiment.')
flags.DEFINE_boolean(
- 'save_checkpoints',
- True,
- 'Whether or not to save checkpoints of the model and optimizer '
- 'at every eval and after training.')
+ 'save_checkpoints',
+ True,
+ 'Whether or not to save checkpoints of the model and optimizer '
+ 'at every eval and after training.',
+)
+flags.DEFINE_boolean(
+ 'save_intermediate_checkpoints',
+ True,
+ 'Whether to save any intermediate checkpoints. '
+ 'If False, it will only keep the latest checkpoint.',
+)
flags.DEFINE_boolean(
- 'save_intermediate_checkpoints',
- True,
- 'Whether to save any intermediate checkpoints. '
- 'If False, it will only keep the latest checkpoint.')
-flags.DEFINE_boolean('resume_last_run',
- None,
- 'Whether to resume the experiment from its last run.')
+ 'resume_last_run', None, 'Whether to resume the experiment from its last run.'
+)
flags.DEFINE_boolean(
- 'append_timestamp',
- False,
- 'If True, the current datetime will be appended to the experiment name. '
- 'Useful for guaranteeing a unique experiment dir for new runs.')
-flags.DEFINE_boolean('use_wandb',
- False,
- 'Whether to use Weights & Biases logging.')
+ 'append_timestamp',
+ False,
+ 'If True, the current datetime will be appended to the experiment name. '
+ 'Useful for guaranteeing a unique experiment dir for new runs.',
+)
+flags.DEFINE_boolean(
+ 'use_wandb', False, 'Whether to use Weights & Biases logging.'
+)
flags.DEFINE_boolean('profile', False, 'Whether to produce profiling output.')
-flags.DEFINE_integer('max_global_steps',
- None,
- 'Maximum number of update steps.')
+flags.DEFINE_integer(
+ 'max_global_steps', None, 'Maximum number of update steps.'
+)
flags.DEFINE_boolean(
- 'overwrite',
- False,
- 'Whether to overwrite the experiment with identical experiment_dir and'
- 'experiment_name.')
+ 'overwrite',
+ False,
+  'Whether to overwrite the experiment with identical experiment_dir and '
+ 'experiment_name.',
+)
flags.DEFINE_integer(
- 'hparam_start_index',
- None,
- 'Start index to slice set of hyperparameters in tuning search space.')
+ 'hparam_start_index',
+ None,
+ 'Start index to slice set of hyperparameters in tuning search space.',
+)
flags.DEFINE_integer(
- 'hparam_end_index',
- None,
- 'End index to slice set of hyperparameters in tuning search space.')
+ 'hparam_end_index',
+ None,
+ 'End index to slice set of hyperparameters in tuning search space.',
+)
flags.DEFINE_integer(
- 'rng_seed',
- None,
- 'Value of rng seed. If None, a random seed will'
- 'be generated from hardware.')
-flags.DEFINE_boolean('set_pytorch_max_split_size',
- False,
- 'If true, set pytorch max_split_size_mb to 256')
+ 'rng_seed',
+ None,
+  'Value of rng seed. If None, a random seed will be generated from hardware.',
+)
+flags.DEFINE_boolean(
+ 'set_pytorch_max_split_size',
+ False,
+ 'If true, set pytorch max_split_size_mb to 256',
+)
flags.DEFINE_integer(
- 'pytorch_eval_num_workers',
- 0,
- 'Number of workers for ImageNet PyTorch evaluation data loaders.'
- 'WARNING: Setting pytorch_eval_num_workers != 0, will result '
- 'in incorrect evals currently, see issues/732.')
+ 'pytorch_eval_num_workers',
+ 0,
+  'Number of workers for ImageNet PyTorch evaluation data loaders. '
+  'WARNING: Setting pytorch_eval_num_workers != 0 will result '
+ 'in incorrect evals currently, see issues/732.',
+)
FLAGS = flags.FLAGS
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
@@ -193,23 +204,23 @@ def _reset_cuda_mem():
def train_once(
- workload: spec.Workload,
- workload_name: str,
- global_batch_size: int,
- global_eval_batch_size: int,
- data_dir: str,
- imagenet_v2_data_dir: str,
- init_optimizer_state: spec.InitOptimizerFn,
- update_params: spec.UpdateParamsFn,
- data_selection: spec.DataSelectionFn,
- prepare_for_eval: Optional[spec.PrepareForEvalFn],
- hyperparameters: Optional[spec.Hyperparameters],
- rng_seed: int,
- rng: spec.RandomState,
- profiler: Profiler,
- max_global_steps: int = None,
- log_dir: Optional[str] = None,
- save_checkpoints: Optional[bool] = True
+ workload: spec.Workload,
+ workload_name: str,
+ global_batch_size: int,
+ global_eval_batch_size: int,
+ data_dir: str,
+ imagenet_v2_data_dir: str,
+ init_optimizer_state: spec.InitOptimizerFn,
+ update_params: spec.UpdateParamsFn,
+ data_selection: spec.DataSelectionFn,
+ prepare_for_eval: Optional[spec.PrepareForEvalFn],
+ hyperparameters: Optional[spec.Hyperparameters],
+ rng_seed: int,
+ rng: spec.RandomState,
+ profiler: Profiler,
+ max_global_steps: int = None,
+ log_dir: Optional[str] = None,
+ save_checkpoints: Optional[bool] = True,
) -> Tuple[spec.Timing, Dict[str, Any]]:
_reset_cuda_mem()
data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4)
@@ -222,47 +233,44 @@ def train_once(
workload.eval_num_workers = FLAGS.pytorch_eval_num_workers
with profiler.profile('Initializing dataset'):
input_queue = workload._build_input_queue(
- data_rng,
- 'train',
- data_dir=data_dir,
- global_batch_size=global_batch_size)
+ data_rng, 'train', data_dir=data_dir, global_batch_size=global_batch_size
+ )
logging.info('Initializing model.')
with profiler.profile('Initializing model'):
- dropout_rate = None
- aux_dropout_rate = None
- if hasattr(hyperparameters, 'dropout_rate'):
- dropout_rate = hyperparameters.dropout_rate
- if hasattr(hyperparameters, 'aux_dropout_rate'):
- aux_dropout_rate = hyperparameters.aux_dropout_rate
- model_params, model_state = workload.init_model_fn(
- model_init_rng, dropout_rate, aux_dropout_rate)
+ model_params, model_state = workload.init_model_fn(model_init_rng)
if FLAGS.framework == 'pytorch' and FLAGS.torch_compile:
compile_error_workloads = [
- 'librispeech_conformer',
- 'ogbg',
- 'criteo1tb',
- 'imagenet_vit',
- 'librispeech_deepspeech'
+ 'librispeech_conformer',
+ 'ogbg',
+ 'criteo1tb',
+ 'imagenet_vit',
+ 'librispeech_deepspeech',
]
eager_backend_workloads = []
aot_eager_backend_workloads = []
loss_compilation_workloads = [
- 'fastmri', 'librispeech_deepspeech', 'ogbg', 'wmt'
+ 'fastmri',
+ 'librispeech_deepspeech',
+ 'ogbg',
+ 'wmt',
]
base_workload = workloads.get_base_workload_name(workload_name)
if base_workload in compile_error_workloads:
logging.warning(
- 'These workloads cannot be fully compiled under current '
- 'PyTorch version. Proceeding without `torch.compile`.')
+ 'These workloads cannot be fully compiled under current '
+ 'PyTorch version. Proceeding without `torch.compile`.'
+ )
elif base_workload in eager_backend_workloads:
logging.warning(
- 'These workloads cannot be fully compiled under current '
- 'PyTorch version. Proceeding with `backend=eager`.')
+ 'These workloads cannot be fully compiled under current '
+ 'PyTorch version. Proceeding with `backend=eager`.'
+ )
model_params = torch.compile(model_params, backend='eager')
elif base_workload in aot_eager_backend_workloads:
logging.warning(
- 'These workloads cannot be fully compiled under current '
- 'PyTorch version. Proceeding with `backend=aot_eager`.')
+ 'These workloads cannot be fully compiled under current '
+ 'PyTorch version. Proceeding with `backend=aot_eager`.'
+ )
model_params = torch.compile(model_params, backend='aot_eager')
else:
logging.info('Performing `torch.compile`.')
@@ -271,11 +279,9 @@ def train_once(
workload.loss_fn = torch.compile(workload.loss_fn)
logging.info('Initializing optimizer.')
with profiler.profile('Initializing optimizer'):
- optimizer_state = init_optimizer_state(workload,
- model_params,
- model_state,
- hyperparameters,
- opt_init_rng)
+ optimizer_state = init_optimizer_state(
+ workload, model_params, model_state, hyperparameters, opt_init_rng
+ )
logging.info('Initializing metrics bundle.')
# Check if 'train_state' is in the function signature
@@ -283,15 +289,15 @@ def train_once(
# Bookkeeping.
train_state = {
- 'validation_goal_reached': False,
- 'test_goal_reached': False,
- 'is_time_remaining': True,
- 'last_eval_time': 0,
- 'training_complete': False,
- 'accumulated_submission_time': 0,
- 'accumulated_eval_time': 0,
- 'accumulated_logging_time': 0,
- 'last_step_end_time': None,
+ 'validation_goal_reached': False,
+ 'test_goal_reached': False,
+ 'is_time_remaining': True,
+ 'last_eval_time': 0,
+ 'training_complete': False,
+ 'accumulated_submission_time': 0,
+ 'accumulated_eval_time': 0,
+ 'accumulated_logging_time': 0,
+ 'last_step_end_time': None,
}
global_step = 0
eval_results = []
@@ -301,22 +307,25 @@ def train_once(
logging.info('Initializing checkpoint and logger.')
if log_dir is not None:
# If the checkpoint exists, load from the checkpoint.
- (optimizer_state,
- model_params,
- model_state,
- train_state,
- eval_results,
- global_step,
- preemption_count) = checkpoint_utils.maybe_restore_checkpoint(
- FLAGS.framework,
- optimizer_state,
- model_params,
- model_state,
- train_state,
- eval_results,
- global_step,
- preemption_count,
- checkpoint_dir=log_dir)
+ (
+ optimizer_state,
+ model_params,
+ model_state,
+ train_state,
+ eval_results,
+ global_step,
+ preemption_count,
+ ) = checkpoint_utils.maybe_restore_checkpoint(
+ FLAGS.framework,
+ optimizer_state,
+ model_params,
+ model_state,
+ train_state,
+ eval_results,
+ global_step,
+ preemption_count,
+ checkpoint_dir=log_dir,
+ )
meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json')
logging.info(f'Saving meta data to {meta_file_name}.')
meta_data = logger_utils.get_meta_data(workload, rng_seed)
@@ -326,9 +335,9 @@ def train_once(
logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict())
metrics_logger = None
if RANK == 0:
- metrics_logger = logger_utils.set_up_loggers(log_dir,
- flags.FLAGS,
- hyperparameters)
+ metrics_logger = logger_utils.set_up_loggers(
+ log_dir, flags.FLAGS, hyperparameters
+ )
workload.attach_metrics_logger(metrics_logger)
global_start_time = get_time()
@@ -336,42 +345,50 @@ def train_once(
logging.info('Starting training loop.')
goals_reached = (
- train_state['validation_goal_reached'] and
- train_state['test_goal_reached'])
- while train_state['is_time_remaining'] and \
- not goals_reached and \
- not train_state['training_complete']:
-
+ train_state['validation_goal_reached'] and train_state['test_goal_reached']
+ )
+ while (
+ train_state['is_time_remaining']
+ and not goals_reached
+ and not train_state['training_complete']
+ ):
step_rng = prng.fold_in(rng, global_step)
- data_select_rng, update_rng, prep_eval_rng, eval_rng = \
- prng.split(step_rng, 4)
+ data_select_rng, update_rng, prep_eval_rng, eval_rng = prng.split(
+ step_rng, 4
+ )
with profiler.profile('Data selection'):
- batch = data_selection(workload,
- input_queue,
- optimizer_state,
- model_params,
- model_state,
- hyperparameters,
- global_step,
- data_select_rng)
+ batch = data_selection(
+ workload,
+ input_queue,
+ optimizer_state,
+ model_params,
+ model_state,
+ hyperparameters,
+ global_step,
+ data_select_rng,
+ )
try:
with profiler.profile('Update parameters'):
optimizer_state, model_params, model_state = update_params(
- workload=workload,
- current_param_container=model_params,
- current_params_types=workload.model_params_types,
- model_state=model_state,
- hyperparameters=hyperparameters,
- batch=batch,
- loss_type=workload.loss_type,
- optimizer_state=optimizer_state,
- eval_results=eval_results,
- global_step=global_step,
- rng=update_rng,
- **({'train_state': MappingProxyType(train_state)}
- if needs_train_state else {}))
+ workload=workload,
+ current_param_container=model_params,
+ current_params_types=workload.model_params_types,
+ model_state=model_state,
+ hyperparameters=hyperparameters,
+ batch=batch,
+ loss_type=workload.loss_type,
+ optimizer_state=optimizer_state,
+ eval_results=eval_results,
+ global_step=global_step,
+ rng=update_rng,
+ **(
+ {'train_state': MappingProxyType(train_state)}
+ if needs_train_state
+ else {}
+ ),
+ )
except spec.TrainingCompleteError:
train_state['training_complete'] = True
global_step += 1
@@ -381,121 +398,139 @@ def train_once(
train_step_end_time = get_time()
train_state['accumulated_submission_time'] += (
- train_step_end_time - train_state['last_step_end_time'])
+ train_step_end_time - train_state['last_step_end_time']
+ )
# Check if submission is eligible for an untimed eval.
- if ((train_step_end_time - train_state['last_eval_time']) >=
- workload.eval_period_time_sec or train_state['training_complete']):
-
+ if (
+ train_step_end_time - train_state['last_eval_time']
+ ) >= workload.eval_period_time_sec or train_state['training_complete']:
# Prepare for evaluation (timed).
if prepare_for_eval is not None:
-
with profiler.profile('Prepare for eval'):
del batch
prepare_for_eval_start_time = get_time()
optimizer_state, model_params, model_state = prepare_for_eval(
- workload=workload,
- current_param_container=model_params,
- current_params_types=workload.model_params_types,
- model_state=model_state,
- hyperparameters=hyperparameters,
- loss_type=workload.loss_type,
- optimizer_state=optimizer_state,
- eval_results=eval_results,
- global_step=global_step,
- rng=prep_eval_rng)
+ workload=workload,
+ current_param_container=model_params,
+ current_params_types=workload.model_params_types,
+ model_state=model_state,
+ hyperparameters=hyperparameters,
+ loss_type=workload.loss_type,
+ optimizer_state=optimizer_state,
+ eval_results=eval_results,
+ global_step=global_step,
+ rng=prep_eval_rng,
+ )
prepare_for_eval_end_time = get_time()
# Update sumbission time.
train_state['accumulated_submission_time'] += (
- prepare_for_eval_end_time - prepare_for_eval_start_time)
+ prepare_for_eval_end_time - prepare_for_eval_start_time
+ )
# Check if time is remaining,
# use 1.5x the runtime budget for the self-tuning ruleset.
max_allowed_runtime_sec = (
- workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external'
- else 1.5 * workload.max_allowed_runtime_sec)
+ workload.max_allowed_runtime_sec
+ if FLAGS.tuning_ruleset == 'external'
+ else 1.5 * workload.max_allowed_runtime_sec
+ )
train_state['is_time_remaining'] = (
- train_state['accumulated_submission_time'] < max_allowed_runtime_sec)
+ train_state['accumulated_submission_time'] < max_allowed_runtime_sec
+ )
# Eval if time is remaining (untimed).
if train_state['is_time_remaining']:
-
with profiler.profile('Evaluation'):
_reset_cuda_mem()
try:
eval_start_time = get_time()
- latest_eval_result = workload.eval_model(global_eval_batch_size,
- model_params,
- model_state,
- eval_rng,
- data_dir,
- imagenet_v2_data_dir,
- global_step)
+ latest_eval_result = workload.eval_model(
+ global_eval_batch_size,
+ model_params,
+ model_state,
+ eval_rng,
+ data_dir,
+ imagenet_v2_data_dir,
+ global_step,
+ )
# Check if targets reached.
# Note that this is one of the stopping conditions for the length of
# a training run. To score the run we only consider the time
# to validation target retrospectively.
train_state['validation_goal_reached'] = (
- workload.has_reached_validation_target(latest_eval_result) or
- train_state['validation_goal_reached'])
+ workload.has_reached_validation_target(latest_eval_result)
+ or train_state['validation_goal_reached']
+ )
train_state['test_goal_reached'] = (
- workload.has_reached_test_target(latest_eval_result) or
- train_state['test_goal_reached'])
+ workload.has_reached_test_target(latest_eval_result)
+ or train_state['test_goal_reached']
+ )
goals_reached = (
- train_state['validation_goal_reached'] and
- train_state['test_goal_reached'])
+ train_state['validation_goal_reached']
+ and train_state['test_goal_reached']
+ )
# Save last eval time.
eval_end_time = get_time()
train_state['last_eval_time'] = eval_end_time
# Accumulate eval time.
- train_state[
- 'accumulated_eval_time'] += eval_end_time - eval_start_time
+ train_state['accumulated_eval_time'] += (
+ eval_end_time - eval_start_time
+ )
# Add times to eval results for logging.
- latest_eval_result['score'] = (
- train_state['accumulated_submission_time'])
- latest_eval_result[
- 'total_duration'] = eval_end_time - global_start_time
+ latest_eval_result['score'] = train_state[
+ 'accumulated_submission_time'
+ ]
+ latest_eval_result['total_duration'] = (
+ eval_end_time - global_start_time
+ )
latest_eval_result['accumulated_submission_time'] = train_state[
- 'accumulated_submission_time']
+ 'accumulated_submission_time'
+ ]
latest_eval_result['accumulated_eval_time'] = train_state[
- 'accumulated_eval_time']
+ 'accumulated_eval_time'
+ ]
latest_eval_result['accumulated_logging_time'] = train_state[
- 'accumulated_logging_time']
+ 'accumulated_logging_time'
+ ]
time_since_start = latest_eval_result['total_duration']
- logging.info(f'Time since start: {time_since_start:.2f}s, '
- f'\tStep: {global_step}, \t{latest_eval_result}')
+ logging.info(
+ f'Time since start: {time_since_start:.2f}s, '
+ f'\tStep: {global_step}, \t{latest_eval_result}'
+ )
eval_results.append((global_step, latest_eval_result))
logging_start_time = get_time()
if log_dir is not None and RANK == 0:
metrics_logger.append_scalar_metrics(
- latest_eval_result,
- global_step=global_step,
- preemption_count=preemption_count,
- is_eval=True,
+ latest_eval_result,
+ global_step=global_step,
+ preemption_count=preemption_count,
+ is_eval=True,
)
if save_checkpoints:
checkpoint_utils.save_checkpoint(
- framework=FLAGS.framework,
- optimizer_state=optimizer_state,
- model_params=model_params,
- model_state=model_state,
- train_state=train_state,
- eval_results=eval_results,
- global_step=global_step,
- preemption_count=preemption_count,
- checkpoint_dir=log_dir,
- save_intermediate_checkpoints=FLAGS
- .save_intermediate_checkpoints)
+ framework=FLAGS.framework,
+ optimizer_state=optimizer_state,
+ model_params=model_params,
+ model_state=model_state,
+ train_state=train_state,
+ eval_results=eval_results,
+ global_step=global_step,
+ preemption_count=preemption_count,
+ checkpoint_dir=log_dir,
+ save_intermediate_checkpoints=FLAGS.save_intermediate_checkpoints,
+ )
logging_end_time = get_time()
train_state['accumulated_logging_time'] += (
- logging_end_time - logging_start_time)
+ logging_end_time - logging_start_time
+ )
_reset_cuda_mem()
@@ -503,8 +538,9 @@ def train_once(
logging.exception(f'Eval step {global_step} error.\n')
if 'out of memory' in str(e):
logging.warning(
- 'Error: GPU out of memory during eval during step '
- f'{global_step}, error : {str(e)}.')
+            'Error: GPU out of memory during eval at step '
+            f'{global_step}, error: {str(e)}.'
+ )
_reset_cuda_mem()
train_state['last_step_end_time'] = get_time()
@@ -513,41 +549,45 @@ def train_once(
if log_dir is not None and RANK == 0:
metrics_logger.append_scalar_metrics(
- {'score': train_state['accumulated_submission_time']},
- global_step=global_step,
- preemption_count=preemption_count)
+ {'score': train_state['accumulated_submission_time']},
+ global_step=global_step,
+ preemption_count=preemption_count,
+ )
metrics_logger.finish()
if save_checkpoints:
checkpoint_utils.save_checkpoint(
- framework=FLAGS.framework,
- optimizer_state=optimizer_state,
- model_params=model_params,
- model_state=model_state,
- train_state=train_state,
- eval_results=eval_results,
- global_step=global_step,
- preemption_count=preemption_count,
- checkpoint_dir=log_dir,
- save_intermediate_checkpoints=FLAGS.save_intermediate_checkpoints)
+ framework=FLAGS.framework,
+ optimizer_state=optimizer_state,
+ model_params=model_params,
+ model_state=model_state,
+ train_state=train_state,
+ eval_results=eval_results,
+ global_step=global_step,
+ preemption_count=preemption_count,
+ checkpoint_dir=log_dir,
+ save_intermediate_checkpoints=FLAGS.save_intermediate_checkpoints,
+ )
return train_state['accumulated_submission_time'], metrics
-def score_submission_on_workload(workload: spec.Workload,
- workload_name: str,
- submission_path: str,
- data_dir: str,
- tuning_ruleset: str,
- profiler: Optional[Profiler] = None,
- max_global_steps: Optional[int] = None,
- imagenet_v2_data_dir: Optional[str] = None,
- tuning_search_space: Optional[str] = None,
- num_tuning_trials: Optional[int] = None,
- log_dir: Optional[str] = None,
- save_checkpoints: Optional[bool] = True,
- hparam_start_index: Optional[bool] = None,
- hparam_end_index: Optional[bool] = None,
- rng_seed: Optional[int] = None):
+def score_submission_on_workload(
+ workload: spec.Workload,
+ workload_name: str,
+ submission_path: str,
+ data_dir: str,
+ tuning_ruleset: str,
+ profiler: Optional[Profiler] = None,
+ max_global_steps: Optional[int] = None,
+ imagenet_v2_data_dir: Optional[str] = None,
+ tuning_search_space: Optional[str] = None,
+ num_tuning_trials: Optional[int] = None,
+ log_dir: Optional[str] = None,
+ save_checkpoints: Optional[bool] = True,
+ hparam_start_index: Optional[bool] = None,
+ hparam_end_index: Optional[bool] = None,
+ rng_seed: Optional[int] = None,
+):
# Expand paths because '~' may not be recognized
data_dir = os.path.expanduser(data_dir)
if imagenet_v2_data_dir:
@@ -571,18 +611,21 @@ def score_submission_on_workload(workload: spec.Workload,
n_gpus = max(N_GPUS, jax.local_device_count())
if global_batch_size % n_gpus != 0:
raise ValueError(
- f'The global batch size ({global_batch_size}) has to be divisible by '
- f'the number of GPUs ({n_gpus}).')
+ f'The global batch size ({global_batch_size}) has to be divisible by '
+ f'the number of GPUs ({n_gpus}).'
+ )
if hasattr(submission_module, 'get_eval_batch_size'):
# If the user specifies the eval batch size, use the provided one.
global_eval_batch_size = submission_module.get_eval_batch_size(
- workload_name)
+ workload_name
+ )
else:
global_eval_batch_size = workload.eval_batch_size
if global_eval_batch_size % n_gpus != 0:
raise ValueError(
- f'The global eval batch size ({global_eval_batch_size}) has to be '
- f'divisible by the number of GPUs ({n_gpus}).')
+ f'The global eval batch size ({global_eval_batch_size}) has to be '
+ f'divisible by the number of GPUs ({n_gpus}).'
+ )
if tuning_ruleset == 'external':
# If the submission runner is responsible for hyperparameter tuning, load in
@@ -590,15 +633,18 @@ def score_submission_on_workload(workload: spec.Workload,
# settings from it.
if tuning_search_space is None:
raise ValueError(
- 'Must provide a tuning search space JSON file when using external '
- 'tuning.')
+ 'Must provide a tuning search space JSON file when using external '
+ 'tuning.'
+ )
with open(tuning_search_space, 'r', encoding='UTF-8') as search_space_file:
tuning_search_space = halton.generate_search(
- json.load(search_space_file), num_tuning_trials)
+ json.load(search_space_file), num_tuning_trials
+ )
all_timings = {}
all_metrics = {}
tuning_search_space_iter = itertools.islice(
- enumerate(tuning_search_space), hparam_start_index, hparam_end_index)
+ enumerate(tuning_search_space), hparam_start_index, hparam_end_index
+ )
for hi, hyperparameters in tuning_search_space_iter:
# Generate a new seed from hardware sources of randomness for each trial.
if not rng_seed:
@@ -622,25 +668,31 @@ def score_submission_on_workload(workload: spec.Workload,
# If existing hyperparameter exists, use saved
# hyperparameters for consistency.
- hyperparameters = logger_utils.write_hparams(hyperparameters,
- tuning_dir_name)
+ hyperparameters = logger_utils.write_hparams(
+ hyperparameters, tuning_dir_name
+ )
tuning_search_space[hi] = hyperparameters
with profiler.profile('Train'):
- timing, metrics = train_once(workload, workload_name,
- global_batch_size,
- global_eval_batch_size,
- data_dir, imagenet_v2_data_dir,
- init_optimizer_state,
- update_params, data_selection,
- prepare_for_eval,
- hyperparameters,
- rng_seed,
- rng,
- profiler,
- max_global_steps,
- tuning_dir_name,
- save_checkpoints=save_checkpoints,)
+ timing, metrics = train_once(
+ workload,
+ workload_name,
+ global_batch_size,
+ global_eval_batch_size,
+ data_dir,
+ imagenet_v2_data_dir,
+ init_optimizer_state,
+ update_params,
+ data_selection,
+ prepare_for_eval,
+ hyperparameters,
+ rng_seed,
+ rng,
+ profiler,
+ max_global_steps,
+ tuning_dir_name,
+ save_checkpoints=save_checkpoints,
+ )
all_timings[hi] = timing
all_metrics[hi] = metrics
logging.info(f'Tuning trial {hi + 1}/{num_tuning_trials}')
@@ -654,7 +706,8 @@ def score_submission_on_workload(workload: spec.Workload,
else:
if tuning_search_space is not None:
raise ValueError(
- 'Cannot provide a tuning search space when using self tuning.')
+ 'Cannot provide a tuning search space when using self tuning.'
+ )
if not rng_seed:
rng_seed = struct.unpack('q', os.urandom(8))[0]
rng = prng.PRNGKey(rng_seed)
@@ -666,11 +719,24 @@ def score_submission_on_workload(workload: spec.Workload,
logger_utils.makedir(log_dir)
with profiler.profile('Train'):
score, _ = train_once(
- workload, workload_name, global_batch_size, global_eval_batch_size,
- data_dir, imagenet_v2_data_dir,
- init_optimizer_state, update_params, data_selection, prepare_for_eval,
- None, rng_seed, rng, profiler, max_global_steps, log_dir,
- save_checkpoints=save_checkpoints)
+ workload,
+ workload_name,
+ global_batch_size,
+ global_eval_batch_size,
+ data_dir,
+ imagenet_v2_data_dir,
+ init_optimizer_state,
+ update_params,
+ data_selection,
+ prepare_for_eval,
+ None,
+ rng_seed,
+ rng,
+ profiler,
+ max_global_steps,
+ log_dir,
+ save_checkpoints=save_checkpoints,
+ )
return score
@@ -694,59 +760,66 @@ def main(_):
# TODO: remove once issue resolved.
if FLAGS.pytorch_eval_num_workers != 0:
logging.warning(
- 'WARNING: Setting pytorch_eval_num_workers != 0, will result '
- 'in incorrect evals currently, see issues/732.')
+      'WARNING: Setting pytorch_eval_num_workers != 0 will result '
+ 'in incorrect evals currently, see issues/732.'
+ )
workload_metadata = WORKLOADS[FLAGS.workload]
if base_workload in [
- 'librispeech_conformer',
- 'librispeech_deepspeech',
- 'imagenet_vit',
- 'criteo1tb'
+ 'librispeech_conformer',
+ 'librispeech_deepspeech',
+ 'imagenet_vit',
+ 'criteo1tb',
]:
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80'
# Extend path according to framework.
workload_metadata['workload_path'] = os.path.join(
- BASE_WORKLOADS_DIR,
- workload_metadata['workload_path'] + f'_{FLAGS.framework}',
- 'workload.py')
+ BASE_WORKLOADS_DIR,
+ workload_metadata['workload_path'] + f'_{FLAGS.framework}',
+ 'workload.py',
+ )
workload_init_kwargs = {}
if FLAGS.librispeech_tokenizer_vocab_path:
workload_init_kwargs['tokenizer_vocab_path'] = (
- FLAGS.librispeech_tokenizer_vocab_path)
+ FLAGS.librispeech_tokenizer_vocab_path
+ )
workload = workloads.import_workload(
- workload_path=workload_metadata['workload_path'],
- workload_class_name=workload_metadata['workload_class_name'],
- workload_init_kwargs=workload_init_kwargs)
+ workload_path=workload_metadata['workload_path'],
+ workload_class_name=workload_metadata['workload_class_name'],
+ workload_init_kwargs=workload_init_kwargs,
+ )
experiment_name = FLAGS.experiment_name
if experiment_name and FLAGS.append_timestamp:
experiment_name += datetime.datetime.now().strftime('-%Y-%m-%d-%H-%M-%S')
- logging_dir_path = logger_utils.get_log_dir(FLAGS.experiment_dir,
- FLAGS.workload,
- FLAGS.framework,
- experiment_name,
- FLAGS.resume_last_run,
- FLAGS.overwrite)
+ logging_dir_path = logger_utils.get_log_dir(
+ FLAGS.experiment_dir,
+ FLAGS.workload,
+ FLAGS.framework,
+ experiment_name,
+ FLAGS.resume_last_run,
+ FLAGS.overwrite,
+ )
score = score_submission_on_workload(
- workload=workload,
- workload_name=FLAGS.workload,
- submission_path=FLAGS.submission_path,
- data_dir=FLAGS.data_dir,
- tuning_ruleset=FLAGS.tuning_ruleset,
- profiler=profiler,
- max_global_steps=FLAGS.max_global_steps,
- imagenet_v2_data_dir=FLAGS.imagenet_v2_data_dir,
- tuning_search_space=FLAGS.tuning_search_space,
- num_tuning_trials=FLAGS.num_tuning_trials,
- log_dir=logging_dir_path,
- save_checkpoints=FLAGS.save_checkpoints,
- hparam_start_index=FLAGS.hparam_start_index,
- hparam_end_index=FLAGS.hparam_end_index,
- rng_seed=FLAGS.rng_seed)
+ workload=workload,
+ workload_name=FLAGS.workload,
+ submission_path=FLAGS.submission_path,
+ data_dir=FLAGS.data_dir,
+ tuning_ruleset=FLAGS.tuning_ruleset,
+ profiler=profiler,
+ max_global_steps=FLAGS.max_global_steps,
+ imagenet_v2_data_dir=FLAGS.imagenet_v2_data_dir,
+ tuning_search_space=FLAGS.tuning_search_space,
+ num_tuning_trials=FLAGS.num_tuning_trials,
+ log_dir=logging_dir_path,
+ save_checkpoints=FLAGS.save_checkpoints,
+ hparam_start_index=FLAGS.hparam_start_index,
+ hparam_end_index=FLAGS.hparam_end_index,
+ rng_seed=FLAGS.rng_seed,
+ )
logging.info(f'Final {FLAGS.workload} score: {score}')
if FLAGS.profile:
diff --git a/submissions/submission_checker.py b/submissions/submission_checker.py
index ab657c0f0..f8af9fb52 100644
--- a/submissions/submission_checker.py
+++ b/submissions/submission_checker.py
@@ -29,7 +29,6 @@
import argparse
import logging
import os
-import subprocess
SELF_TUNING = 'self_tuning'
EXTERNAL_TUNING = 'external_tuning'
@@ -41,7 +40,8 @@ def _check_ruleset_subdirs(submission_dir):
contents = os.listdir(submission_dir)
if not ((EXTERNAL_TUNING in contents) or (SELF_TUNING in contents)):
logging.info(
- f'CHECK FAILED: {submission_dir} does not contain ruleset subdir.')
+ f'CHECK FAILED: {submission_dir} does not contain ruleset subdir.'
+ )
return False
return True
@@ -54,7 +54,7 @@ def _check_submission_module(submission_dir):
contents = os.listdir(os.path.join(root, submission_dir))
if SUBMISSION_MODULE not in contents:
logging.info(
- f'CHECK FAILED: {parent_dir}/{submission_dir} does not contain {SUBMISSION_MODULE}'
+ f'CHECK FAILED: {parent_dir}/{submission_dir} does not contain {SUBMISSION_MODULE}'
)
return False
return True
@@ -68,7 +68,7 @@ def _check_tuning_search_space_file(submission_dir):
contents = os.listdir(os.path.join(root, submission_dir))
if TUNING_SEARCH_SPACE_FILENAME not in contents:
logging.info(
- f'CHECK FAILED: {parent_dir}/{submission_dir} does not contain {TUNING_SEARCH_SPACE_FILENAME}'
+ f'CHECK FAILED: {parent_dir}/{submission_dir} does not contain {TUNING_SEARCH_SPACE_FILENAME}'
)
return False
return True
@@ -76,18 +76,22 @@ def _check_tuning_search_space_file(submission_dir):
def run_checks(submission_dir):
"""Top-level checker function.
- Call individual checkers from this function.
- """
+ Call individual checkers from this function.
+ """
logging.info('Running repository checks.')
# Execute checks
contains_ruleset_subdirs = _check_ruleset_subdirs(submission_dir)
contains_submission_module = _check_submission_module(submission_dir)
contains_tuning_search_space_file = _check_tuning_search_space_file(
- submission_dir)
+ submission_dir
+ )
- if not (contains_ruleset_subdirs and contains_submission_module and
- contains_tuning_search_space_file):
+ if not (
+ contains_ruleset_subdirs
+ and contains_submission_module
+ and contains_tuning_search_space_file
+ ):
logging.info('TESTS FAILED.')
return False
@@ -98,16 +102,17 @@ def run_checks(submission_dir):
def get_parser():
"""Parse commandline."""
parser = argparse.ArgumentParser(
- description='Checks for submission folder for AlgoPerf',)
+ description='Checks for submission folder for AlgoPerf',
+ )
parser.add_argument(
- 'folder',
- type=str,
- help='the folder for a submission package.',
+ 'folder',
+ type=str,
+ help='the folder for a submission package.',
)
parser.add_argument(
- '--log_output',
- type=str,
- default='submission_checker.log',
+ '--log_output',
+ type=str,
+ default='submission_checker.log',
)
return parser
@@ -118,7 +123,7 @@ def main():
logging.basicConfig(filename=args.log_output, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler())
- formatter = logging.Formatter("%(levelname)s - %(message)s")
+ formatter = logging.Formatter('%(levelname)s - %(message)s')
logging.getLogger().handlers[0].setFormatter(formatter)
logging.getLogger().handlers[1].setFormatter(formatter)
diff --git a/submissions/template/submission.py b/submissions/template/submission.py
index a4fdc62b4..db6900afd 100644
--- a/submissions/template/submission.py
+++ b/submissions/template/submission.py
@@ -4,96 +4,102 @@
and https://github.com/mlcommons/algorithmic-efficiency/blob/main/DOCUMENTATION.md#disallowed-submissions
for guidelines.
"""
+
from typing import Any, Dict, Iterator, List, Optional, Tuple
from algoperf import spec
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule.
- Returns:
- optimizer state
- optimizer_update_fn
- """
+  Returns: spec.OptimizerState, the initialized optimizer state.
+ """
pass
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
+ """
+ Returns:
+ spec.OptimizerState: new optimizer state
+    spec.ParameterContainer: new params
+    spec.ModelAuxiliaryState: new model state
"""
- Returns:
- (new_optimizer_state, update_fn)
- new_params
- new_model_state
- """
pass
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
+ """
+ Returns:
+ new_optimizer_state
+ new_params
+ new_model_state
"""
- Returns:
- new_optimizer_state
- new_params
- new_model_state
- """
pass
def get_batch_size(workload_name):
"""
- Gets batch size for workload.
- Note that these batch sizes only apply during training and not during evals.
- Args:
- workload_name (str): Valid workload_name values are: "wmt", "ogbg",
- "criteo1tb", "fastmri", "imagenet_resnet", "imagenet_vit",
- "librispeech_deepspeech", "librispeech_conformer" or any of the
- variants.
- Returns:
- int: batch_size
- Raises:
- ValueError: If workload_name is not handled.
- """
+ Gets batch size for workload.
+ Note that these batch sizes only apply during training and not during evals.
+ Args:
+ workload_name (str): Valid workload_name values are: "wmt", "ogbg",
+ "criteo1tb", "fastmri", "imagenet_resnet", "imagenet_vit",
+ "librispeech_deepspeech", "librispeech_conformer" or any of the
+ variants.
+ Returns:
+ int: batch_size
+ Raises:
+ ValueError: If workload_name is not handled.
+ """
pass
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
- Each element of the queue is a batch of training examples and labels.
- Tip:
- If you would just like the next batch from the input queue return next(input_queue).
+ Each element of the queue is a batch of training examples and labels.
+ Tip:
+  If you would just like the next batch from the input queue, return next(input_queue).
- Returns:
- batch: next batch of input data
- """
+ Returns:
+ batch: next batch of input data
+ """
pass
diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py
index 9de61a2a5..48f658d06 100644
--- a/tests/modeldiffs/criteo1tb/compare.py
+++ b/tests/modeldiffs/criteo1tb/compare.py
@@ -8,10 +8,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \
- Criteo1TbDlrmSmallWorkload as JaxWorkload
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
- Criteo1TbDlrmSmallWorkload as PyTorchWorkload
+from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import (
+ Criteo1TbDlrmSmallWorkload as JaxWorkload,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import (
+ Criteo1TbDlrmSmallWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -54,31 +56,34 @@ def sd_transform(sd):
pytorch_workload = PyTorchWorkload()
pytorch_batch = {
- 'inputs': torch.ones((2, 13 + 26)),
- 'targets': torch.randint(low=0, high=1, size=(2,)),
- 'weights': torch.ones(2),
+ 'inputs': torch.ones((2, 13 + 26)),
+ 'targets': torch.randint(low=0, high=1, size=(2,)),
+ 'weights': torch.ones(2),
}
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
# Test outputs for identical weights and inputs.
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/criteo1tb_embed_init/compare.py b/tests/modeldiffs/criteo1tb_embed_init/compare.py
index f1897d16f..897920bac 100644
--- a/tests/modeldiffs/criteo1tb_embed_init/compare.py
+++ b/tests/modeldiffs/criteo1tb_embed_init/compare.py
@@ -8,10 +8,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \
- Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
- Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload
+from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import (
+ Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import (
+ Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -53,31 +55,34 @@ def sd_transform(sd):
pytorch_workload = PyTorchWorkload()
pytorch_batch = {
- 'inputs': torch.ones((2, 13 + 26)),
- 'targets': torch.randint(low=0, high=1, size=(2,)),
- 'weights': torch.ones(2),
+ 'inputs': torch.ones((2, 13 + 26)),
+ 'targets': torch.randint(low=0, high=1, size=(2,)),
+ 'weights': torch.ones(2),
}
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
# Test outputs for identical weights and inputs.
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py
index 5aad3cc67..db1dec601 100644
--- a/tests/modeldiffs/criteo1tb_layernorm/compare.py
+++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py
@@ -8,10 +8,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \
- Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
- Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload
+from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import (
+ Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import (
+ Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -65,31 +67,34 @@ def sd_transform(sd):
pytorch_workload = PyTorchWorkload()
pytorch_batch = {
- 'inputs': torch.ones((2, 13 + 26)),
- 'targets': torch.randint(low=0, high=1, size=(2,)),
- 'weights': torch.ones(2),
+ 'inputs': torch.ones((2, 13 + 26)),
+ 'targets': torch.randint(low=0, high=1, size=(2,)),
+ 'weights': torch.ones(2),
}
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
# Test outputs for identical weights and inputs.
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- # mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ # mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py
index 169b1cdf4..4851a8ad4 100644
--- a/tests/modeldiffs/criteo1tb_resnet/compare.py
+++ b/tests/modeldiffs/criteo1tb_resnet/compare.py
@@ -9,10 +9,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \
- Criteo1TbDlrmSmallResNetWorkload as JaxWorkload
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
- Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload
+from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import (
+ Criteo1TbDlrmSmallResNetWorkload as JaxWorkload,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import (
+ Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -65,9 +67,9 @@ def sd_transform(sd):
pytorch_workload = PyTorchWorkload()
pytorch_batch = {
- 'inputs': torch.ones((2, 13 + 26)),
- 'targets': torch.randint(low=0, high=1, size=(2,)),
- 'weights': torch.ones(2),
+ 'inputs': torch.ones((2, 13 + 26)),
+ 'targets': torch.randint(low=0, high=1, size=(2,)),
+ 'weights': torch.ones(2),
}
init_fake_batch_size = 2
@@ -80,23 +82,26 @@ def sd_transform(sd):
# Test outputs for identical weights and inputs.
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/diff.py b/tests/modeldiffs/diff.py
index 52241fd3a..a74159028 100644
--- a/tests/modeldiffs/diff.py
+++ b/tests/modeldiffs/diff.py
@@ -1,21 +1,23 @@
-from flax import jax_utils
-from flax.core import FrozenDict
import jax
import numpy as np
import torch
+from flax import jax_utils
+from flax.core import FrozenDict
-from tests.modeldiffs.torch2jax_utils import Torch2Jax
-from tests.modeldiffs.torch2jax_utils import value_transform
+from tests.modeldiffs.torch2jax_utils import Torch2Jax, value_transform
-#pylint: disable=dangerous-default-value
-def torch2jax(jax_workload,
- pytorch_workload,
- key_transform=None,
- sd_transform=None,
- init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0)):
- jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0),
- **init_kwargs)
+# pylint: disable=dangerous-default-value
+def torch2jax(
+ jax_workload,
+ pytorch_workload,
+ key_transform=None,
+ sd_transform=None,
+ init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0),
+):
+ jax_params, model_state = jax_workload.init_model_fn(
+ jax.random.PRNGKey(0), **init_kwargs
+ )
pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs)
if isinstance(jax_params, dict):
jax_params = FrozenDict(jax_params)
@@ -24,8 +26,9 @@ def torch2jax(jax_workload,
model_state = jax_utils.unreplicate(model_state)
if isinstance(
- pytorch_model,
- (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
+ pytorch_model,
+ (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel),
+ ):
pytorch_model = pytorch_model.module
# Map and copy params of pytorch_model to jax_model.
t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params)
@@ -39,22 +42,24 @@ def torch2jax(jax_workload,
return jax_params, model_state, pytorch_model
-def out_diff(jax_workload,
- pytorch_workload,
- jax_model_kwargs,
- pytorch_model_kwargs,
- key_transform=None,
- sd_transform=None,
- out_transform=None):
- jax_params, model_state, pytorch_model = torch2jax(jax_workload,
- pytorch_workload,
- key_transform,
- sd_transform)
- out_p, _ = pytorch_workload.model_fn(params=pytorch_model,
- **pytorch_model_kwargs)
- out_j, _ = jax_workload.model_fn(params=jax_params,
- model_state=model_state,
- **jax_model_kwargs)
+def out_diff(
+ jax_workload,
+ pytorch_workload,
+ jax_model_kwargs,
+ pytorch_model_kwargs,
+ key_transform=None,
+ sd_transform=None,
+ out_transform=None,
+):
+ jax_params, model_state, pytorch_model = torch2jax(
+ jax_workload, pytorch_workload, key_transform, sd_transform
+ )
+ out_p, _ = pytorch_workload.model_fn(
+ params=pytorch_model, **pytorch_model_kwargs
+ )
+ out_j, _ = jax_workload.model_fn(
+ params=jax_params, model_state=model_state, **jax_model_kwargs
+ )
if out_transform is not None:
out_p = out_transform(out_p)
out_j = out_transform(out_j)
@@ -67,15 +72,16 @@ def out_diff(jax_workload,
class ModelDiffRunner:
-
- def __init__(self,
- jax_workload,
- pytorch_workload,
- jax_model_kwargs,
- pytorch_model_kwargs,
- key_transform=None,
- sd_transform=None,
- out_transform=None) -> None:
+ def __init__(
+ self,
+ jax_workload,
+ pytorch_workload,
+ jax_model_kwargs,
+ pytorch_model_kwargs,
+ key_transform=None,
+ sd_transform=None,
+ out_transform=None,
+ ) -> None:
"""
Initializes the instance based on diffing logic.
@@ -83,7 +89,7 @@ def __init__(self,
jax_workload: Workload implementation using JAX.
pytorch_workload: Workload implementation using PyTorch.
jax_model_kwargs: Arguments to be used for model_fn in jax workload.
- pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch
+ pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch
workload.
key_transform: Transformation function for keys.
sd_transform: Transformation function for State Dictionary.
@@ -99,10 +105,12 @@ def __init__(self,
self.out_transform = out_transform
def run(self):
- out_diff(self.jax_workload,
- self.pytorch_workload,
- self.jax_model_kwargs,
- self.pytorch_model_kwargs,
- self.key_transform,
- self.sd_transform,
- self.out_transform)
+ out_diff(
+ self.jax_workload,
+ self.pytorch_workload,
+ self.jax_model_kwargs,
+ self.pytorch_model_kwargs,
+ self.key_transform,
+ self.sd_transform,
+ self.out_transform,
+ )
diff --git a/tests/modeldiffs/fastmri/compare.py b/tests/modeldiffs/fastmri/compare.py
index c1a349cec..6a82bfb58 100644
--- a/tests/modeldiffs/fastmri/compare.py
+++ b/tests/modeldiffs/fastmri/compare.py
@@ -7,15 +7,16 @@
import torch
from algoperf import spec
-from algoperf.workloads.fastmri.fastmri_jax.workload import \
- FastMRIWorkload as JaxWorkload
-from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
- FastMRIWorkload as PyTorchWorkload
+from algoperf.workloads.fastmri.fastmri_jax.workload import (
+ FastMRIWorkload as JaxWorkload,
+)
+from algoperf.workloads.fastmri.fastmri_pytorch.workload import (
+ FastMRIWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
def sd_transform(sd):
-
def sort_key(k):
if k[0] == 'ModuleList_0':
return (0, *k)
@@ -34,7 +35,7 @@ def sort_key(k):
if 'ModuleList' in i or 'Sequential' in i:
continue
if i.startswith('ConvBlock'):
- if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]:
+ if idx != 0 and keys[idx - 1][: idx2 + 1] != k[: idx2 + 1]:
c += 1
i = f'ConvBlock_{c}'
if 'Conv2d' in i:
@@ -64,22 +65,25 @@ def sort_key(k):
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=None,
- sd_transform=sd_transform).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=None,
+ sd_transform=sd_transform,
+ ).run()
diff --git a/tests/modeldiffs/fastmri_layernorm/compare.py b/tests/modeldiffs/fastmri_layernorm/compare.py
index f26ad185e..8ad47bcae 100644
--- a/tests/modeldiffs/fastmri_layernorm/compare.py
+++ b/tests/modeldiffs/fastmri_layernorm/compare.py
@@ -7,15 +7,16 @@
import torch
from algoperf import spec
-from algoperf.workloads.fastmri.fastmri_jax.workload import \
- FastMRILayerNormWorkload as JaxWorkload
-from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
- FastMRILayerNormWorkload as PyTorchWorkload
+from algoperf.workloads.fastmri.fastmri_jax.workload import (
+ FastMRILayerNormWorkload as JaxWorkload,
+)
+from algoperf.workloads.fastmri.fastmri_pytorch.workload import (
+ FastMRILayerNormWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
def sd_transform(sd):
-
def sort_key(k):
if k[0] == 'ModuleList_0':
return (0, *k)
@@ -35,7 +36,7 @@ def sort_key(k):
if 'ModuleList' in i or 'Sequential' in i:
continue
if i.startswith('ConvBlock'):
- if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]:
+ if idx != 0 and keys[idx - 1][: idx2 + 1] != k[: idx2 + 1]:
c += 1
i = f'ConvBlock_{c}'
if 'Conv2d' in i:
@@ -71,22 +72,25 @@ def sort_key(k):
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=None,
- sd_transform=sd_transform).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=None,
+ sd_transform=sd_transform,
+ ).run()
diff --git a/tests/modeldiffs/fastmri_model_size/compare.py b/tests/modeldiffs/fastmri_model_size/compare.py
index 42789539b..f6d5c5074 100644
--- a/tests/modeldiffs/fastmri_model_size/compare.py
+++ b/tests/modeldiffs/fastmri_model_size/compare.py
@@ -7,15 +7,16 @@
import torch
from algoperf import spec
-from algoperf.workloads.fastmri.fastmri_jax.workload import \
- FastMRIModelSizeWorkload as JaxWorkload
-from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
- FastMRIModelSizeWorkload as PyTorchWorkload
+from algoperf.workloads.fastmri.fastmri_jax.workload import (
+ FastMRIModelSizeWorkload as JaxWorkload,
+)
+from algoperf.workloads.fastmri.fastmri_pytorch.workload import (
+ FastMRIModelSizeWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
def sd_transform(sd):
-
def sort_key(k):
if k[0] == 'ModuleList_0':
return (0, *k)
@@ -34,7 +35,7 @@ def sort_key(k):
if 'ModuleList' in i or 'Sequential' in i:
continue
if i.startswith('ConvBlock'):
- if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]:
+ if idx != 0 and keys[idx - 1][: idx2 + 1] != k[: idx2 + 1]:
c += 1
i = f'ConvBlock_{c}'
if 'Conv2d' in i:
@@ -64,22 +65,25 @@ def sort_key(k):
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=None,
- sd_transform=sd_transform).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=None,
+ sd_transform=sd_transform,
+ ).run()
diff --git a/tests/modeldiffs/fastmri_tanh/compare.py b/tests/modeldiffs/fastmri_tanh/compare.py
index 13ecb890c..714a025b3 100644
--- a/tests/modeldiffs/fastmri_tanh/compare.py
+++ b/tests/modeldiffs/fastmri_tanh/compare.py
@@ -7,15 +7,16 @@
import torch
from algoperf import spec
-from algoperf.workloads.fastmri.fastmri_jax.workload import \
- FastMRITanhWorkload as JaxWorkload
-from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
- FastMRITanhWorkload as PyTorchWorkload
+from algoperf.workloads.fastmri.fastmri_jax.workload import (
+ FastMRITanhWorkload as JaxWorkload,
+)
+from algoperf.workloads.fastmri.fastmri_pytorch.workload import (
+ FastMRITanhWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
def sd_transform(sd):
-
def sort_key(k):
if k[0] == 'ModuleList_0':
return (0, *k)
@@ -34,7 +35,7 @@ def sort_key(k):
if 'ModuleList' in i or 'Sequential' in i:
continue
if i.startswith('ConvBlock'):
- if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]:
+ if idx != 0 and keys[idx - 1][: idx2 + 1] != k[: idx2 + 1]:
c += 1
i = f'ConvBlock_{c}'
if 'Conv2d' in i:
@@ -64,22 +65,25 @@ def sort_key(k):
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=None,
- sd_transform=sd_transform).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=None,
+ sd_transform=sd_transform,
+ ).run()
diff --git a/tests/modeldiffs/imagenet_resnet/compare.py b/tests/modeldiffs/imagenet_resnet/compare.py
index 59ab45555..e43cd069e 100644
--- a/tests/modeldiffs/imagenet_resnet/compare.py
+++ b/tests/modeldiffs/imagenet_resnet/compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \
- ImagenetResNetWorkload as JaxWorkload
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \
- ImagenetResNetWorkload as PyTorchWorkload
+from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import (
+ ImagenetResNetWorkload as JaxWorkload,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import (
+ ImagenetResNetWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -53,11 +55,13 @@ def sd_transform(sd):
c += 1
new_key = (f'BottleneckResNetBlock_{c}',) + k[2:]
if 'Sequential' in ''.join(new_key):
- new_key = tuple([
+ new_key = tuple(
+ [
(i.replace('_0', '_proj') if 'BatchNorm' in i or 'Conv' in i else i)
for i in new_key
if 'Sequential' not in i
- ])
+ ]
+ )
sd[new_key] = sd[k]
del sd[k]
elif 'BatchNorm' in k[0] or 'Conv' in k[0]:
@@ -81,22 +85,25 @@ def sd_transform(sd):
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ ).run()
diff --git a/tests/modeldiffs/imagenet_resnet/gelu_compare.py b/tests/modeldiffs/imagenet_resnet/gelu_compare.py
index 07510ad70..a92712ddc 100644
--- a/tests/modeldiffs/imagenet_resnet/gelu_compare.py
+++ b/tests/modeldiffs/imagenet_resnet/gelu_compare.py
@@ -7,13 +7,14 @@
import torch
from algoperf import spec
-from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \
- ImagenetResNetGELUWorkload as JaxWorkload
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \
- ImagenetResNetGELUWorkload as PyTorchWorkload
+from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import (
+ ImagenetResNetGELUWorkload as JaxWorkload,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import (
+ ImagenetResNetGELUWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
-from tests.modeldiffs.imagenet_resnet.compare import key_transform
-from tests.modeldiffs.imagenet_resnet.compare import sd_transform
+from tests.modeldiffs.imagenet_resnet.compare import key_transform, sd_transform
if __name__ == '__main__':
# pylint: disable=locally-disabled, not-callable
@@ -28,22 +29,25 @@
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ ).run()
diff --git a/tests/modeldiffs/imagenet_resnet/silu_compare.py b/tests/modeldiffs/imagenet_resnet/silu_compare.py
index 8246d17a2..bbbfd082b 100644
--- a/tests/modeldiffs/imagenet_resnet/silu_compare.py
+++ b/tests/modeldiffs/imagenet_resnet/silu_compare.py
@@ -7,13 +7,14 @@
import torch
from algoperf import spec
-from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \
- ImagenetResNetSiLUWorkload as JaxWorkload
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \
- ImagenetResNetSiLUWorkload as PyTorchWorkload
+from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import (
+ ImagenetResNetSiLUWorkload as JaxWorkload,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import (
+ ImagenetResNetSiLUWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
-from tests.modeldiffs.imagenet_resnet.compare import key_transform
-from tests.modeldiffs.imagenet_resnet.compare import sd_transform
+from tests.modeldiffs.imagenet_resnet.compare import key_transform, sd_transform
if __name__ == '__main__':
# pylint: disable=locally-disabled, not-callable
@@ -28,22 +29,25 @@
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ ).run()
diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py
index b4ca7d8ec..84282d4be 100644
--- a/tests/modeldiffs/imagenet_vit/compare.py
+++ b/tests/modeldiffs/imagenet_vit/compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \
- ImagenetVitWorkload as JaxWorkload
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \
- ImagenetVitWorkload as PyTorchWorkload
+from algoperf.workloads.imagenet_vit.imagenet_jax.workload import (
+ ImagenetVitWorkload as JaxWorkload,
+)
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import (
+ ImagenetVitWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -40,16 +42,16 @@ def key_transform(k):
if attention:
if pool_head:
i = {
- 'Linear_0': 'query',
- 'Linear_1': 'key_value',
- 'Linear_2': 'out',
+ 'Linear_0': 'query',
+ 'Linear_1': 'key_value',
+ 'Linear_2': 'out',
}[i]
else:
i = {
- 'Linear_0': 'query',
- 'Linear_1': 'key',
- 'Linear_2': 'value',
- 'Linear_3': 'out',
+ 'Linear_0': 'query',
+ 'Linear_1': 'key',
+ 'Linear_2': 'value',
+ 'Linear_3': 'out',
}[i]
else:
i = i.replace('Linear', 'Dense')
@@ -94,22 +96,25 @@ def key_transform(k):
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/imagenet_vit_glu/compare.py b/tests/modeldiffs/imagenet_vit_glu/compare.py
index c152410b5..55c010b97 100644
--- a/tests/modeldiffs/imagenet_vit_glu/compare.py
+++ b/tests/modeldiffs/imagenet_vit_glu/compare.py
@@ -10,10 +10,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \
- ImagenetVitGluWorkload as JaxWorkload
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \
- ImagenetVitGluWorkload as PyTorchWorkload
+from algoperf.workloads.imagenet_vit.imagenet_jax.workload import (
+ ImagenetVitGluWorkload as JaxWorkload,
+)
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import (
+ ImagenetVitGluWorkload as PyTorchWorkload,
+)
sd_transform = None
@@ -30,22 +32,25 @@
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/imagenet_vit_map/compare.py b/tests/modeldiffs/imagenet_vit_map/compare.py
index 7f1af41ab..17a7483c2 100644
--- a/tests/modeldiffs/imagenet_vit_map/compare.py
+++ b/tests/modeldiffs/imagenet_vit_map/compare.py
@@ -10,10 +10,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \
- ImagenetVitMapWorkload as JaxWorkload
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \
- ImagenetVitMapWorkload as PytWorkload
+from algoperf.workloads.imagenet_vit.imagenet_jax.workload import (
+ ImagenetVitMapWorkload as JaxWorkload,
+)
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import (
+ ImagenetVitMapWorkload as PytWorkload,
+)
def sd_transform(sd):
@@ -41,22 +43,25 @@ def sd_transform(sd):
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ ).run()
diff --git a/tests/modeldiffs/imagenet_vit_postln/compare.py b/tests/modeldiffs/imagenet_vit_postln/compare.py
index a3a639101..72d407a4b 100644
--- a/tests/modeldiffs/imagenet_vit_postln/compare.py
+++ b/tests/modeldiffs/imagenet_vit_postln/compare.py
@@ -10,10 +10,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \
- ImagenetVitPostLNWorkload as JaxWorkload
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \
- ImagenetViTPostLNWorkload as PyTorchWorkload
+from algoperf.workloads.imagenet_vit.imagenet_jax.workload import (
+ ImagenetVitPostLNWorkload as JaxWorkload,
+)
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import (
+ ImagenetViTPostLNWorkload as PyTorchWorkload,
+)
sd_transform = None
@@ -30,22 +32,25 @@
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/librispeech_conformer/compare.py b/tests/modeldiffs/librispeech_conformer/compare.py
index 664b1242d..80aba62f2 100644
--- a/tests/modeldiffs/librispeech_conformer/compare.py
+++ b/tests/modeldiffs/librispeech_conformer/compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \
- LibriSpeechConformerWorkload as JaxWorkload
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \
- LibriSpeechConformerWorkload as PyTorchWorkload
+from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import (
+ LibriSpeechConformerWorkload as JaxWorkload,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import (
+ LibriSpeechConformerWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -69,24 +71,27 @@ def sd_transform(sd):
pytorch_batch = {'inputs': (wave, pad)}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=lambda out_outpad: out_outpad[0] *
- (1 - out_outpad[1][:, :, None])).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=lambda out_outpad: out_outpad[0]
+ * (1 - out_outpad[1][:, :, None]),
+ ).run()
diff --git a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py
index b0812e77d..a7bebf6bf 100644
--- a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py
+++ b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \
- LibriSpeechConformerAttentionTemperatureWorkload as JaxWorkload
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \
- LibriSpeechConformerAttentionTemperatureWorkload as PyTorchWorkload
+from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import (
+ LibriSpeechConformerAttentionTemperatureWorkload as JaxWorkload,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import (
+ LibriSpeechConformerAttentionTemperatureWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -69,24 +71,27 @@ def sd_transform(sd):
pytorch_batch = {'inputs': (wave, pad)}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=lambda out_outpad: out_outpad[0] *
- (1 - out_outpad[1][:, :, None])).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=lambda out_outpad: out_outpad[0]
+ * (1 - out_outpad[1][:, :, None]),
+ ).run()
diff --git a/tests/modeldiffs/librispeech_conformer_gelu/compare.py b/tests/modeldiffs/librispeech_conformer_gelu/compare.py
index 3032a0005..d8f7980a2 100644
--- a/tests/modeldiffs/librispeech_conformer_gelu/compare.py
+++ b/tests/modeldiffs/librispeech_conformer_gelu/compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \
- LibriSpeechConformerGeluWorkload as JaxWorkload
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \
- LibriSpeechConformerGeluWorkload as PyTorchWorkload
+from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import (
+ LibriSpeechConformerGeluWorkload as JaxWorkload,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import (
+ LibriSpeechConformerGeluWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -69,24 +71,27 @@ def sd_transform(sd):
pytorch_batch = {'inputs': (wave, pad)}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=lambda out_outpad: out_outpad[0] *
- (1 - out_outpad[1][:, :, None])).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=lambda out_outpad: out_outpad[0]
+ * (1 - out_outpad[1][:, :, None]),
+ ).run()
diff --git a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py
index d623ef352..7f4768c11 100644
--- a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py
+++ b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \
- LibriSpeechConformerLayerNormWorkload as JaxWorkload
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \
- LibriSpeechConformerLayerNormWorkload as PyTorchWorkload
+from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import (
+ LibriSpeechConformerLayerNormWorkload as JaxWorkload,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import (
+ LibriSpeechConformerLayerNormWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -69,24 +71,27 @@ def sd_transform(sd):
pytorch_batch = {'inputs': (wave, pad)}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=lambda out_outpad: out_outpad[0] *
- (1 - out_outpad[1][:, :, None])).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=lambda out_outpad: out_outpad[0]
+ * (1 - out_outpad[1][:, :, None]),
+ ).run()
diff --git a/tests/modeldiffs/librispeech_deepspeech/compare.py b/tests/modeldiffs/librispeech_deepspeech/compare.py
index 84b0a6c86..bd1073524 100644
--- a/tests/modeldiffs/librispeech_deepspeech/compare.py
+++ b/tests/modeldiffs/librispeech_deepspeech/compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \
- LibriSpeechDeepSpeechWorkload as JaxWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \
- LibriSpeechDeepSpeechWorkload as PyTorchWorkload
+from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import (
+ LibriSpeechDeepSpeechWorkload as JaxWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import (
+ LibriSpeechDeepSpeechWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -56,9 +58,9 @@ def sd_transform(sd):
else:
out[k] = sd[k]
elif 'LSTM' in ''.join(k):
- l = out.get(k[:-1], dict())
- l[k[-1]] = sd[k]
- out[k[:-1]] = l
+ l_tmp = out.get(k[:-1], dict())
+ l_tmp[k[-1]] = sd[k]
+ out[k[:-1]] = l_tmp
else:
out[k] = sd[k]
keys_to_del = []
@@ -67,10 +69,12 @@ def sd_transform(sd):
if isinstance(out[k], dict):
kernels = ['kernel_ih_l0', 'kernel_hh_l0']
biases = ['bias_ih_l0', 'bias_hh_l0']
- weights = torch.cat([out[k][i].view(-1) for i in kernels] +
- [out[k][i + '_reverse'].view(-1) for i in kernels] +
- [out[k][i].view(-1) for i in biases] +
- [out[k][i + '_reverse'].view(-1) for i in biases])
+ weights = torch.cat(
+ [out[k][i].view(-1) for i in kernels]
+ + [out[k][i + '_reverse'].view(-1) for i in kernels]
+ + [out[k][i].view(-1) for i in biases]
+ + [out[k][i + '_reverse'].view(-1) for i in biases]
+ )
updates[k + ('weights',)] = weights
keys_to_del.append(k)
out.update(updates)
@@ -94,24 +98,27 @@ def sd_transform(sd):
pytorch_batch = {'inputs': (wave, pad)}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=lambda out_outpad: out_outpad[0] *
- (1 - out_outpad[1][:, :, None])).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=lambda out_outpad: out_outpad[0]
+ * (1 - out_outpad[1][:, :, None]),
+ ).run()
diff --git a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py
index 2540c1b93..8593894e4 100644
--- a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py
+++ b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py
@@ -7,13 +7,17 @@
import torch
from algoperf import spec
-from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \
- LibriSpeechDeepSpeechTanhWorkload as JaxWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \
- LibriSpeechDeepSpeechTanhWorkload as PyTorchWorkload
+from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import (
+ LibriSpeechDeepSpeechTanhWorkload as JaxWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import (
+ LibriSpeechDeepSpeechTanhWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
-from tests.modeldiffs.librispeech_deepspeech.compare import key_transform
-from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform
+from tests.modeldiffs.librispeech_deepspeech.compare import (
+ key_transform,
+ sd_transform,
+)
if __name__ == '__main__':
# pylint: disable=locally-disabled, not-callable
@@ -30,24 +34,27 @@
pytorch_batch = {'inputs': (wave, pad)}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=lambda out_outpad: out_outpad[0] *
- (1 - out_outpad[1][:, :, None])).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=lambda out_outpad: out_outpad[0]
+ * (1 - out_outpad[1][:, :, None]),
+ ).run()
diff --git a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py
index e5972120d..27e4760a6 100644
--- a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py
+++ b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py
@@ -7,13 +7,17 @@
import torch
from algoperf import spec
-from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \
- LibriSpeechDeepSpeechNormAndSpecAugWorkload as JaxWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \
- LibriSpeechDeepSpeechNormAndSpecAugWorkload as PyTorchWorkload
+from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import (
+ LibriSpeechDeepSpeechNormAndSpecAugWorkload as JaxWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import (
+ LibriSpeechDeepSpeechNormAndSpecAugWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
-from tests.modeldiffs.librispeech_deepspeech.compare import key_transform
-from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform
+from tests.modeldiffs.librispeech_deepspeech.compare import (
+ key_transform,
+ sd_transform,
+)
if __name__ == '__main__':
# pylint: disable=locally-disabled, not-callable
@@ -30,24 +34,27 @@
pytorch_batch = {'inputs': (wave, pad)}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=lambda out_outpad: out_outpad[0] *
- (1 - out_outpad[1][:, :, None])).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=lambda out_outpad: out_outpad[0]
+ * (1 - out_outpad[1][:, :, None]),
+ ).run()
diff --git a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py
index 4d2c4a5d5..7990f063b 100644
--- a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py
+++ b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py
@@ -7,13 +7,17 @@
import torch
from algoperf import spec
-from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \
- LibriSpeechDeepSpeechNoResNetWorkload as JaxWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \
- LibriSpeechDeepSpeechNoResNetWorkload as PyTorchWorkload
+from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import (
+ LibriSpeechDeepSpeechNoResNetWorkload as JaxWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import (
+ LibriSpeechDeepSpeechNoResNetWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
-from tests.modeldiffs.librispeech_deepspeech.compare import key_transform
-from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform
+from tests.modeldiffs.librispeech_deepspeech.compare import (
+ key_transform,
+ sd_transform,
+)
if __name__ == '__main__':
# pylint: disable=locally-disabled, not-callable
@@ -30,24 +34,27 @@
pytorch_batch = {'inputs': (wave, pad)}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=lambda out_outpad: out_outpad[0] *
- (1 - out_outpad[1][:, :, None])).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=lambda out_outpad: out_outpad[0]
+ * (1 - out_outpad[1][:, :, None]),
+ ).run()
diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py
index 5d5ef50bf..8c40d3c8a 100644
--- a/tests/modeldiffs/ogbg/compare.py
+++ b/tests/modeldiffs/ogbg/compare.py
@@ -9,10 +9,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.ogbg.ogbg_jax.workload import \
- OgbgWorkload as JaxWorkload
-from algoperf.workloads.ogbg.ogbg_pytorch.workload import \
- OgbgWorkload as PyTorchWorkload
+from algoperf.workloads.ogbg.ogbg_jax.workload import (
+ OgbgWorkload as JaxWorkload,
+)
+from algoperf.workloads.ogbg.ogbg_pytorch.workload import (
+ OgbgWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
# Todo: refactor tests to use workload properties in cleaner way
@@ -41,8 +43,8 @@ def key_transform(k):
layer_index = int(i.split('_')[1])
if graph_network:
count = (
- graph_index * 3 * hidden_dims + seq_index * hidden_dims +
- layer_index)
+ graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index
+ )
i = 'Dense_' + str(count)
elif layer_index == 0:
i = 'node_embedding'
@@ -54,7 +56,8 @@ def key_transform(k):
elif 'LayerNorm' in i:
layer_index = int(i.split('_')[1])
count = (
- graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index)
+ graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index
+ )
i = 'LayerNorm_' + str(count)
elif 'weight' in i:
if bn or ln:
@@ -80,13 +83,14 @@ def sd_transform(sd):
pytorch_workload = PyTorchWorkload()
pytorch_batch = dict(
- n_node=torch.LongTensor([5]),
- n_edge=torch.LongTensor([5]),
- nodes=torch.randn(5, 9),
- edges=torch.randn(5, 3),
- globals=torch.randn(1, 128),
- senders=torch.LongTensor(list(range(5))),
- receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]))
+ n_node=torch.LongTensor([5]),
+ n_edge=torch.LongTensor([5]),
+ nodes=torch.randn(5, 9),
+ edges=torch.randn(5, 3),
+ globals=torch.randn(1, 128),
+ senders=torch.LongTensor(list(range(5))),
+ receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]),
+ )
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
@@ -98,23 +102,26 @@ def sd_transform(sd):
pytorch_batch = {'inputs': graph_p}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py
index fc3992998..f35ed8b17 100644
--- a/tests/modeldiffs/ogbg_gelu/compare.py
+++ b/tests/modeldiffs/ogbg_gelu/compare.py
@@ -9,10 +9,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.ogbg.ogbg_jax.workload import \
- OgbgGeluWorkload as JaxWorkload
-from algoperf.workloads.ogbg.ogbg_pytorch.workload import \
- OgbgGeluWorkload as PyTorchWorkload
+from algoperf.workloads.ogbg.ogbg_jax.workload import (
+ OgbgGeluWorkload as JaxWorkload,
+)
+from algoperf.workloads.ogbg.ogbg_pytorch.workload import (
+ OgbgGeluWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
# Todo: refactor tests to use workload properties in cleaner way
@@ -41,8 +43,8 @@ def key_transform(k):
layer_index = int(i.split('_')[1])
if graph_network:
count = (
- graph_index * 3 * hidden_dims + seq_index * hidden_dims +
- layer_index)
+ graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index
+ )
i = 'Dense_' + str(count)
elif layer_index == 0:
i = 'node_embedding'
@@ -54,7 +56,8 @@ def key_transform(k):
elif 'LayerNorm' in i:
layer_index = int(i.split('_')[1])
count = (
- graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index)
+ graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index
+ )
i = 'LayerNorm_' + str(count)
elif 'weight' in i:
if bn or ln:
@@ -80,13 +83,14 @@ def sd_transform(sd):
pytorch_workload = PyTorchWorkload()
pytorch_batch = dict(
- n_node=torch.LongTensor([5]),
- n_edge=torch.LongTensor([5]),
- nodes=torch.randn(5, 9),
- edges=torch.randn(5, 3),
- globals=torch.randn(1, 128),
- senders=torch.LongTensor(list(range(5))),
- receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]))
+ n_node=torch.LongTensor([5]),
+ n_edge=torch.LongTensor([5]),
+ nodes=torch.randn(5, 9),
+ edges=torch.randn(5, 3),
+ globals=torch.randn(1, 128),
+ senders=torch.LongTensor(list(range(5))),
+ receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]),
+ )
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
@@ -98,23 +102,26 @@ def sd_transform(sd):
pytorch_batch = {'inputs': graph_p}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py
index e7cfa745c..0042c71af 100644
--- a/tests/modeldiffs/ogbg_model_size/compare.py
+++ b/tests/modeldiffs/ogbg_model_size/compare.py
@@ -9,10 +9,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.ogbg.ogbg_jax.workload import \
- OgbgModelSizeWorkload as JaxWorkload
-from algoperf.workloads.ogbg.ogbg_pytorch.workload import \
- OgbgModelSizeWorkload as PyTorchWorkload
+from algoperf.workloads.ogbg.ogbg_jax.workload import (
+ OgbgModelSizeWorkload as JaxWorkload,
+)
+from algoperf.workloads.ogbg.ogbg_pytorch.workload import (
+ OgbgModelSizeWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
# Todo: refactor tests to use workload properties in cleaner way
@@ -41,8 +43,8 @@ def key_transform(k):
layer_index = int(i.split('_')[1])
if graph_network:
count = (
- graph_index * 3 * hidden_dims + seq_index * hidden_dims +
- layer_index)
+ graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index
+ )
i = 'Dense_' + str(count)
elif layer_index == 0:
i = 'node_embedding'
@@ -54,7 +56,8 @@ def key_transform(k):
elif 'LayerNorm' in i:
layer_index = int(i.split('_')[1])
count = (
- graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index)
+ graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index
+ )
i = 'LayerNorm_' + str(count)
elif 'weight' in i:
if bn or ln:
@@ -80,13 +83,14 @@ def sd_transform(sd):
pytorch_workload = PyTorchWorkload()
pytorch_batch = dict(
- n_node=torch.LongTensor([5]),
- n_edge=torch.LongTensor([5]),
- nodes=torch.randn(5, 9),
- edges=torch.randn(5, 3),
- globals=torch.randn(1, 128),
- senders=torch.LongTensor(list(range(5))),
- receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]))
+ n_node=torch.LongTensor([5]),
+ n_edge=torch.LongTensor([5]),
+ nodes=torch.randn(5, 9),
+ edges=torch.randn(5, 3),
+ globals=torch.randn(1, 128),
+ senders=torch.LongTensor(list(range(5))),
+ receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]),
+ )
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
@@ -98,23 +102,26 @@ def sd_transform(sd):
pytorch_batch = {'inputs': graph_p}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py
index 4e3b96cf7..7583282cd 100644
--- a/tests/modeldiffs/ogbg_silu/compare.py
+++ b/tests/modeldiffs/ogbg_silu/compare.py
@@ -9,10 +9,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.ogbg.ogbg_jax.workload import \
- OgbgSiluWorkload as JaxWorkload
-from algoperf.workloads.ogbg.ogbg_pytorch.workload import \
- OgbgSiluWorkload as PyTorchWorkload
+from algoperf.workloads.ogbg.ogbg_jax.workload import (
+ OgbgSiluWorkload as JaxWorkload,
+)
+from algoperf.workloads.ogbg.ogbg_pytorch.workload import (
+ OgbgSiluWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
# Todo: refactor tests to use workload properties in cleaner way
@@ -41,8 +43,8 @@ def key_transform(k):
layer_index = int(i.split('_')[1])
if graph_network:
count = (
- graph_index * 3 * hidden_dims + seq_index * hidden_dims +
- layer_index)
+ graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index
+ )
i = 'Dense_' + str(count)
elif layer_index == 0:
i = 'node_embedding'
@@ -54,7 +56,8 @@ def key_transform(k):
elif 'LayerNorm' in i:
layer_index = int(i.split('_')[1])
count = (
- graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index)
+ graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index
+ )
i = 'LayerNorm_' + str(count)
elif 'weight' in i:
if bn or ln:
@@ -80,13 +83,14 @@ def sd_transform(sd):
pytorch_workload = PyTorchWorkload()
pytorch_batch = dict(
- n_node=torch.LongTensor([5]),
- n_edge=torch.LongTensor([5]),
- nodes=torch.randn(5, 9),
- edges=torch.randn(5, 3),
- globals=torch.randn(1, 128),
- senders=torch.LongTensor(list(range(5))),
- receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]))
+ n_node=torch.LongTensor([5]),
+ n_edge=torch.LongTensor([5]),
+ nodes=torch.randn(5, 9),
+ edges=torch.randn(5, 3),
+ globals=torch.randn(1, 128),
+ senders=torch.LongTensor(list(range(5))),
+ receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]),
+ )
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
@@ -98,23 +102,26 @@ def sd_transform(sd):
pytorch_batch = {'inputs': graph_p}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py
index d9264b400..4c95ca7e4 100644
--- a/tests/modeldiffs/torch2jax_utils.py
+++ b/tests/modeldiffs/torch2jax_utils.py
@@ -1,5 +1,5 @@
-from collections import Counter
import pprint
+from collections import Counter
def jax_like_pytorch_statedict(model, state_dict, keys=None):
@@ -32,13 +32,17 @@ def flatten(jm, ret, keys=None):
def value_transform(k, value, jax_value):
k_str = ''.join(k).lower()
- if ('conv' in k_str and 'kernel' in k_str) or \
- ('embedding' in k_str and 'kernel' in k_str):
+ if ('conv' in k_str and 'kernel' in k_str) or (
+ 'embedding' in k_str and 'kernel' in k_str
+ ):
if 'transpose' in k_str:
# Assumes 2D ConvTranspose with stride equal to kernel_size.
- return value.reshape(value.shape[0], value.shape[1],
- -1).flip(-1).permute(2, 0,
- 1).reshape(*jax_value.shape)
+ return (
+ value.reshape(value.shape[0], value.shape[1], -1)
+ .flip(-1)
+ .permute(2, 0, 1)
+ .reshape(*jax_value.shape)
+ )
else:
rank = len(value.shape)
if rank == 3:
@@ -51,16 +55,17 @@ def value_transform(k, value, jax_value):
value = value.t().reshape(*list(jax_value.shape))
elif 'attention' in k_str and 'bias' in k_str:
value = value.reshape(*list(jax_value.shape))
- elif ('dense' in k_str and 'kernel' in k_str) or \
- ('lstm' in k_str and 'kernel' in k_str) or \
- ('head' in k_str and 'kernel' in k_str) or \
- ('pre_logits' in k_str and 'kernel' in k_str):
+ elif (
+ ('dense' in k_str and 'kernel' in k_str)
+ or ('lstm' in k_str and 'kernel' in k_str)
+ or ('head' in k_str and 'kernel' in k_str)
+ or ('pre_logits' in k_str and 'kernel' in k_str)
+ ):
value = value.t()
return value
class Torch2Jax:
-
def __init__(self, torch_model, jax_model):
self.torch_model = torch_model
self.jax_model = jax_model
@@ -73,13 +78,13 @@ def __init__(self, torch_model, jax_model):
def key_transform(self, k_transform_fn):
self.pytorch_sd = {
- k_transform_fn(k): self.pytorch_sd[k] for k in self.pytorch_sd
+ k_transform_fn(k): self.pytorch_sd[k] for k in self.pytorch_sd
}
def value_transform(self, v_transform_fn):
self.pytorch_sd = {
- k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k])
- for k in self.pytorch_sd
+ k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k])
+ for k in self.pytorch_sd
}
def sd_transform(self, sd_transform_fn):
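The ConvTranspose relayout in `value_transform` above is easiest to follow with concrete shapes: a PyTorch `ConvTranspose2d` kernel stored as `(in, out, kH, kW)` has its flattened spatial axis reversed and moved to the front, matching Flax's `(kH, kW, in, out)` layout. A short sketch under the same assumption stated in the code (2D ConvTranspose with stride equal to kernel size); the shape values are arbitrary:

```python
# Sketch of the ConvTranspose weight relayout used in value_transform.
import torch

in_ch, out_ch, kh, kw = 4, 6, 2, 2
# PyTorch ConvTranspose2d stores weights as (in_channels, out_channels, kH, kW).
torch_kernel = torch.randn(in_ch, out_ch, kh, kw)
# Flax ConvTranspose stores the kernel as (kH, kW, in_features, out_features).
jax_shape = (kh, kw, in_ch, out_ch)

relayout = (
  torch_kernel.reshape(in_ch, out_ch, -1)  # merge the spatial dims
  .flip(-1)                                # reverse the flattened spatial axis
  .permute(2, 0, 1)                        # spatial axis first: (kH*kW, in, out)
  .reshape(*jax_shape)                     # split spatial back into (kH, kW)
)
assert relayout.shape == jax_shape
```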
diff --git a/tests/modeldiffs/vanilla_sgd_jax.py b/tests/modeldiffs/vanilla_sgd_jax.py
index 5595894e6..aa7bebd4f 100644
--- a/tests/modeldiffs/vanilla_sgd_jax.py
+++ b/tests/modeldiffs/vanilla_sgd_jax.py
@@ -1,28 +1,33 @@
-from flax import jax_utils
import jax
import jax.numpy as jnp
import optax
+from flax import jax_utils
from algoperf import spec
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_submission_base import \
- update_params # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
+ data_selection,
+)
+from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401
+ update_params,
+)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Vanilla SGD Optimizer."""
del model_params
del model_state
del rng
# Create optimizer.
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
opt_init_fn, opt_update_fn = optax.sgd(learning_rate=0.001)
optimizer_state = opt_init_fn(params_zeros_like)
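The `init_optimizer_state` above never touches real parameters: Optax only needs a pytree with the right structure and shapes to build its state, so zero-filled placeholders derived from the workload's shape tree are enough. A minimal sketch with a hand-written shape dict standing in for `workload.param_shapes`:

```python
# Minimal sketch of the vanilla-SGD init pattern, assuming a plain pytree of
# shape tuples instead of the workload's parameter shape tree.
import jax
import jax.numpy as jnp
import optax

param_shapes = {'dense': {'kernel': (784, 10), 'bias': (10,)}}

# Treat shape tuples as leaves and replace each with a zero array of that shape.
params_zeros_like = jax.tree.map(
  jnp.zeros, param_shapes, is_leaf=lambda x: isinstance(x, tuple)
)

opt_init_fn, opt_update_fn = optax.sgd(learning_rate=0.001)
optimizer_state = opt_init_fn(params_zeros_like)

# One illustrative update with all-ones gradients.
grads = jax.tree.map(jnp.ones_like, params_zeros_like)
updates, optimizer_state = opt_update_fn(grads, optimizer_state, params_zeros_like)
new_params = optax.apply_updates(params_zeros_like, updates)
```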
diff --git a/tests/modeldiffs/vanilla_sgd_pytorch.py b/tests/modeldiffs/vanilla_sgd_pytorch.py
index a6a0c5fa6..6448ac097 100644
--- a/tests/modeldiffs/vanilla_sgd_pytorch.py
+++ b/tests/modeldiffs/vanilla_sgd_pytorch.py
@@ -1,24 +1,29 @@
import torch
from algoperf import spec
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \
- update_params # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
+ data_selection,
+)
+from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401
+ update_params,
+)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Vanilla SGD Optimizer."""
del model_state
del rng
optimizer_state = {
- 'optimizer':
- torch.optim.SGD(model_params.parameters(), lr=0.001, weight_decay=0),
+ 'optimizer': torch.optim.SGD(
+ model_params.parameters(), lr=0.001, weight_decay=0
+ ),
}
return optimizer_state
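The PyTorch variant simply wraps a plain `torch.optim.SGD` in the `optimizer_state` dict. A small sketch with an arbitrary `nn.Linear` standing in for the workload's `model_params` container:

```python
# Sketch of the PyTorch counterpart, assuming any nn.Module stands in for the
# workload's model_params.
import torch

model_params = torch.nn.Linear(784, 10)

optimizer_state = {
  'optimizer': torch.optim.SGD(
    model_params.parameters(), lr=0.001, weight_decay=0
  ),
}

# One illustrative step: plain SGD subtracts lr * grad from every parameter.
loss = model_params(torch.randn(8, 784)).sum()
loss.backward()
optimizer_state['optimizer'].step()
optimizer_state['optimizer'].zero_grad()
```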
diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py
index 109bfa629..02175c8b5 100644
--- a/tests/modeldiffs/wmt/compare.py
+++ b/tests/modeldiffs/wmt/compare.py
@@ -8,17 +8,20 @@
from algoperf import spec
from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWorkload
-from algoperf.workloads.wmt.wmt_pytorch.workload import \
- WmtWorkload as PyTorchWorkload
+from algoperf.workloads.wmt.wmt_pytorch.workload import (
+ WmtWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
def key_transform(k):
new_key = []
for i in k:
- if 'ModuleList' in i or\
- 'TransformerDecoder_' in i or\
- 'TransformerEncoder_' in i:
+ if (
+ 'ModuleList' in i
+ or 'TransformerDecoder_' in i
+ or 'TransformerEncoder_' in i
+ ):
continue
if 'Linear' in i:
if 'NonDynamicallyQuantizableLinear' in i:
@@ -60,7 +63,7 @@ def sd_transform(sd):
pass
else:
if new_key[-2] == 'Dense_0':
- #q
+ # q
out[(*new_key[:-2], 'query', new_key[-1])] = sd[k]
pass
elif new_key[-2] == 'Dense_1':
@@ -73,11 +76,11 @@ def sd_transform(sd):
out[(*new_key[:-2], 'out', new_key[-1])] = sd[k]
pass
out = {
- tuple(
- k.replace('SelfAttention', 'MultiHeadDotProductAttention')
- for k in key): value
- for key,
- value in out.items()
+ tuple(
+ k.replace('SelfAttention', 'MultiHeadDotProductAttention')
+ for k in key
+ ): value
+ for key, value in out.items()
}
elif 'Dense' in k_str:
new_key = (*k[:2], 'MlpBlock_0', *k[2:])
@@ -112,29 +115,32 @@ def sd_transform(sd):
tgt_tokens = torch.randint(low=0, high=32000, size=(2, 256))
jax_batch = {
- 'inputs': inp_tokens.detach().numpy(),
- 'targets': tgt_tokens.detach().numpy(),
+ 'inputs': inp_tokens.detach().numpy(),
+ 'targets': tgt_tokens.detach().numpy(),
}
pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
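The `sd_transform` dict comprehension above rewrites every `SelfAttention*` path component to Flax's `MultiHeadDotProductAttention*` while leaving the tuple keys otherwise intact. A toy example of the same rewrite, with made-up keys in place of the real WMT state dict:

```python
# Sketch of the tuple-key renaming performed in sd_transform.
out = {
  ('encoder', 'SelfAttention_0', 'query', 'kernel'): 1,
  ('encoder', 'MlpBlock_0', 'Dense_0', 'kernel'): 2,
}
out = {
  tuple(
    k.replace('SelfAttention', 'MultiHeadDotProductAttention') for k in key
  ): value
  for key, value in out.items()
}
assert ('encoder', 'MultiHeadDotProductAttention_0', 'query', 'kernel') in out
```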
diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py
index 1aa20fe3b..0c834dc86 100644
--- a/tests/modeldiffs/wmt_attention_temp/compare.py
+++ b/tests/modeldiffs/wmt_attention_temp/compare.py
@@ -7,19 +7,23 @@
import torch
from algoperf import spec
-from algoperf.workloads.wmt.wmt_jax.workload import \
- WmtWorkloadAttentionTemp as JaxWorkload
-from algoperf.workloads.wmt.wmt_pytorch.workload import \
- WmtWorkloadAttentionTemp as PyTorchWorkload
+from algoperf.workloads.wmt.wmt_jax.workload import (
+ WmtWorkloadAttentionTemp as JaxWorkload,
+)
+from algoperf.workloads.wmt.wmt_pytorch.workload import (
+ WmtWorkloadAttentionTemp as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
def key_transform(k):
new_key = []
for i in k:
- if 'ModuleList' in i or\
- 'TransformerDecoder_' in i or\
- 'TransformerEncoder_' in i:
+ if (
+ 'ModuleList' in i
+ or 'TransformerDecoder_' in i
+ or 'TransformerEncoder_' in i
+ ):
continue
if 'Linear' in i:
if 'NonDynamicallyQuantizableLinear' in i:
@@ -61,7 +65,7 @@ def sd_transform(sd):
pass
else:
if new_key[-2] == 'Dense_0':
- #q
+ # q
out[(*new_key[:-2], 'query', new_key[-1])] = sd[k]
pass
elif new_key[-2] == 'Dense_1':
@@ -74,11 +78,11 @@ def sd_transform(sd):
out[(*new_key[:-2], 'out', new_key[-1])] = sd[k]
pass
out = {
- tuple(
- k.replace('SelfAttention', 'MultiHeadDotProductAttention')
- for k in key): value
- for key,
- value in out.items()
+ tuple(
+ k.replace('SelfAttention', 'MultiHeadDotProductAttention')
+ for k in key
+ ): value
+ for key, value in out.items()
}
elif 'Dense' in k_str:
new_key = (*k[:2], 'MlpBlock_0', *k[2:])
@@ -113,29 +117,32 @@ def sd_transform(sd):
tgt_tokens = torch.randint(low=0, high=32000, size=(2, 256))
jax_batch = {
- 'inputs': inp_tokens.detach().numpy(),
- 'targets': tgt_tokens.detach().numpy(),
+ 'inputs': inp_tokens.detach().numpy(),
+ 'targets': tgt_tokens.detach().numpy(),
}
pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py
index e98a6945d..f7de12326 100644
--- a/tests/modeldiffs/wmt_glu_tanh/compare.py
+++ b/tests/modeldiffs/wmt_glu_tanh/compare.py
@@ -7,19 +7,23 @@
import torch
from algoperf import spec
-from algoperf.workloads.wmt.wmt_jax.workload import \
- WmtWorkloadGLUTanH as JaxWorkload
-from algoperf.workloads.wmt.wmt_pytorch.workload import \
- WmtWorkloadGLUTanH as PyTorchWorkload
+from algoperf.workloads.wmt.wmt_jax.workload import (
+ WmtWorkloadGLUTanH as JaxWorkload,
+)
+from algoperf.workloads.wmt.wmt_pytorch.workload import (
+ WmtWorkloadGLUTanH as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
def key_transform(k):
new_key = []
for i in k:
- if 'ModuleList' in i or\
- 'TransformerDecoder_' in i or\
- 'TransformerEncoder_' in i:
+ if (
+ 'ModuleList' in i
+ or 'TransformerDecoder_' in i
+ or 'TransformerEncoder_' in i
+ ):
continue
if 'Linear' in i:
if 'NonDynamicallyQuantizableLinear' in i:
@@ -61,7 +65,7 @@ def sd_transform(sd):
pass
else:
if new_key[-2] == 'Dense_0':
- #q
+ # q
out[(*new_key[:-2], 'query', new_key[-1])] = sd[k]
pass
elif new_key[-2] == 'Dense_1':
@@ -74,11 +78,11 @@ def sd_transform(sd):
out[(*new_key[:-2], 'out', new_key[-1])] = sd[k]
pass
out = {
- tuple(
- k.replace('SelfAttention', 'MultiHeadDotProductAttention')
- for k in key): value
- for key,
- value in out.items()
+ tuple(
+ k.replace('SelfAttention', 'MultiHeadDotProductAttention')
+ for k in key
+ ): value
+ for key, value in out.items()
}
elif 'Dense' in k_str:
new_key = (*k[:2], 'MlpBlock_0', *k[2:])
@@ -113,29 +117,32 @@ def sd_transform(sd):
tgt_tokens = torch.randint(low=0, high=32000, size=(2, 256))
jax_batch = {
- 'inputs': inp_tokens.detach().numpy(),
- 'targets': tgt_tokens.detach().numpy(),
+ 'inputs': inp_tokens.detach().numpy(),
+ 'targets': tgt_tokens.detach().numpy(),
}
pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py
index d110715b5..a8681ca8e 100644
--- a/tests/modeldiffs/wmt_post_ln/compare.py
+++ b/tests/modeldiffs/wmt_post_ln/compare.py
@@ -7,19 +7,23 @@
import torch
from algoperf import spec
-from algoperf.workloads.wmt.wmt_jax.workload import \
- WmtWorkloadPostLN as JaxWorkload
-from algoperf.workloads.wmt.wmt_pytorch.workload import \
- WmtWorkloadPostLN as PyTorchWorkload
+from algoperf.workloads.wmt.wmt_jax.workload import (
+ WmtWorkloadPostLN as JaxWorkload,
+)
+from algoperf.workloads.wmt.wmt_pytorch.workload import (
+ WmtWorkloadPostLN as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
def key_transform(k):
new_key = []
for i in k:
- if 'ModuleList' in i or\
- 'TransformerDecoder_' in i or\
- 'TransformerEncoder_' in i:
+ if (
+ 'ModuleList' in i
+ or 'TransformerDecoder_' in i
+ or 'TransformerEncoder_' in i
+ ):
continue
if 'Linear' in i:
if 'NonDynamicallyQuantizableLinear' in i:
@@ -61,7 +65,7 @@ def sd_transform(sd):
pass
else:
if new_key[-2] == 'Dense_0':
- #q
+ # q
out[(*new_key[:-2], 'query', new_key[-1])] = sd[k]
pass
elif new_key[-2] == 'Dense_1':
@@ -74,11 +78,11 @@ def sd_transform(sd):
out[(*new_key[:-2], 'out', new_key[-1])] = sd[k]
pass
out = {
- tuple(
- k.replace('SelfAttention', 'MultiHeadDotProductAttention')
- for k in key): value
- for key,
- value in out.items()
+ tuple(
+ k.replace('SelfAttention', 'MultiHeadDotProductAttention')
+ for k in key
+ ): value
+ for key, value in out.items()
}
elif 'Dense' in k_str:
new_key = (*k[:2], 'MlpBlock_0', *k[2:])
@@ -113,29 +117,32 @@ def sd_transform(sd):
tgt_tokens = torch.randint(low=0, high=32000, size=(2, 256))
jax_batch = {
- 'inputs': inp_tokens.detach().numpy(),
- 'targets': tgt_tokens.detach().numpy(),
+ 'inputs': inp_tokens.detach().numpy(),
+ 'targets': tgt_tokens.detach().numpy(),
}
pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py
index c4ca514a8..d17848aaf 100644
--- a/tests/reference_algorithm_tests.py
+++ b/tests/reference_algorithm_tests.py
@@ -27,67 +27,66 @@
import os
import pickle
-from absl import flags
-from absl import logging
-from absl.testing import absltest
import flax
-from flax import jax_utils
-from flax.core.frozen_dict import FrozenDict
import jax
-from jraph import GraphsTuple
import numpy as np
import tensorflow as tf
import torch
import torch.distributed as dist
+from absl import flags, logging
+from absl.testing import absltest
+from flax import jax_utils
+from flax.core.frozen_dict import FrozenDict
+from jraph import GraphsTuple
-from algoperf import halton
-from algoperf import pytorch_utils
+import submission_runner
+from algoperf import halton, pytorch_utils
from algoperf import random_utils as prng
from algoperf.profiler import PassThroughProfiler
from algoperf.workloads import workloads
from algoperf.workloads.ogbg import input_pipeline as ogbg_input_pipeline
from algoperf.workloads.ogbg.ogbg_pytorch.workload import _graph_map
-import submission_runner
from tests.modeldiffs import diff as diff_utils
flags.DEFINE_integer(
- 'global_batch_size',
- -1,
- ('Global Batch size to use when running an individual workload. Otherwise '
- 'a per-device batch size of 2 is used.'))
+ 'global_batch_size',
+ -1,
+ (
+ 'Global Batch size to use when running an individual workload. Otherwise '
+ 'a per-device batch size of 2 is used.'
+ ),
+)
flags.DEFINE_integer('num_train_steps', 1, 'Number of steps to train.')
flags.DEFINE_boolean('use_fake_input_queue', True, 'Use fake data examples.')
flags.DEFINE_string('log_file', '/tmp/log.pkl', 'The log file')
flags.DEFINE_boolean(
- 'all',
- False,
- 'Run all workloads instead of using --workload and --framework.')
-flags.DEFINE_boolean('identical',
- False,
- 'Run jax and pytorch with identical weights.')
+ 'all', False, 'Run all workloads instead of using --workload and --framework.'
+)
+flags.DEFINE_boolean(
+ 'identical', False, 'Run jax and pytorch with identical weights.'
+)
FLAGS = flags.FLAGS
USE_PYTORCH_DDP, RANK, PYTORCH_DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
tf.config.set_visible_devices([], 'GPU')
_EXPECTED_METRIC_NAMES = {
- 'cifar': ['train/loss', 'validation/loss', 'test/accuracy'],
- 'criteo1tb': ['train/loss', 'validation/loss'],
- 'criteo1tb_test': ['train/loss', 'validation/loss'],
- 'fastmri': ['train/ssim', 'validation/ssim'],
- 'imagenet_resnet': ['train/accuracy', 'validation/accuracy'],
- 'imagenet_vit': ['train/accuracy', 'validation/accuracy'],
- 'librispeech_conformer': ['train/wer', 'validation/wer', 'train/ctc_loss'],
- 'librispeech_deepspeech': ['train/wer', 'validation/wer', 'train/ctc_loss'],
- 'mnist': ['train/loss', 'validation/accuracy', 'test/accuracy'],
- 'ogbg': [
- 'train/accuracy', 'validation/loss', 'test/mean_average_precision'
- ],
- 'wmt': ['train/bleu', 'validation/loss', 'validation/accuracy'],
+ 'cifar': ['train/loss', 'validation/loss', 'test/accuracy'],
+ 'criteo1tb': ['train/loss', 'validation/loss'],
+ 'criteo1tb_test': ['train/loss', 'validation/loss'],
+ 'fastmri': ['train/ssim', 'validation/ssim'],
+ 'imagenet_resnet': ['train/accuracy', 'validation/accuracy'],
+ 'imagenet_vit': ['train/accuracy', 'validation/accuracy'],
+ 'librispeech_conformer': ['train/wer', 'validation/wer', 'train/ctc_loss'],
+ 'librispeech_deepspeech': ['train/wer', 'validation/wer', 'train/ctc_loss'],
+ 'mnist': ['train/loss', 'validation/accuracy', 'test/accuracy'],
+ 'ogbg': ['train/accuracy', 'validation/loss', 'test/mean_average_precision'],
+ 'wmt': ['train/bleu', 'validation/loss', 'validation/accuracy'],
}
def _make_fake_image_batch(batch_shape, data_shape, num_classes):
- examples = np.random.normal(size=(*batch_shape,
- *data_shape)).astype(np.float32)
+ examples = np.random.normal(size=(*batch_shape, *data_shape)).astype(
+ np.float32
+ )
labels = np.random.randint(0, num_classes, size=batch_shape)
masks = np.ones(batch_shape, dtype=np.float32)
return {'inputs': examples, 'targets': labels, 'weights': masks}
@@ -96,16 +95,17 @@ def _make_fake_image_batch(batch_shape, data_shape, num_classes):
def _pytorch_map(inputs):
if USE_PYTORCH_DDP:
return jax.tree.map(
- lambda a: torch.as_tensor(a[RANK], device=PYTORCH_DEVICE), inputs)
+ lambda a: torch.as_tensor(a[RANK], device=PYTORCH_DEVICE), inputs
+ )
return jax.tree.map(
- lambda a: torch.as_tensor(a, device=PYTORCH_DEVICE).view(-1, a.shape[-1])
- if len(a.shape) == 3 else torch.as_tensor(a, device=PYTORCH_DEVICE).view(
- -1),
- inputs)
+ lambda a: torch.as_tensor(a, device=PYTORCH_DEVICE).view(-1, a.shape[-1])
+ if len(a.shape) == 3
+ else torch.as_tensor(a, device=PYTORCH_DEVICE).view(-1),
+ inputs,
+ )
class _FakeTokenizer:
-
def detokenize(self, *args):
del args
return tf.constant('this is a fake sequence?')
@@ -113,15 +113,14 @@ def detokenize(self, *args):
@flax.struct.dataclass
class _FakeMetricsCollection:
-
def merge(self, *args):
del args
return self
def compute(self):
return {
- 'wer': 0.0,
- 'ctc_loss': 0.0,
+ 'wer': 0.0,
+ 'ctc_loss': 0.0,
}
def unreplicate(self):
@@ -129,7 +128,6 @@ def unreplicate(self):
class _FakeMetricsLogger:
-
def __init__(self):
self.filename = FLAGS.log_file
self.scalars = []
@@ -152,27 +150,27 @@ def append_eval_metrics(self, result):
def save(self):
with open(self.filename, 'wb') as f:
- pickle.dump({'scalars': self.scalars, 'eval_results': self.eval_results},
- f)
+ pickle.dump(
+ {'scalars': self.scalars, 'eval_results': self.eval_results}, f
+ )
class _FakeMetricsBundle:
-
def gather_from_model_output(self, *args, **kwargs):
del args
del kwargs
return _FakeMetricsCollection()
-def _make_one_batch_workload(workload_class,
- workload_name,
- framework,
- global_batch_size,
- use_fake_input_queue,
- n_gpus):
-
+def _make_one_batch_workload(
+ workload_class,
+ workload_name,
+ framework,
+ global_batch_size,
+ use_fake_input_queue,
+ n_gpus,
+):
class _OneEvalBatchWorkload(workload_class):
-
def __init__(self):
kwargs = {}
if 'librispeech' in workload_name:
@@ -184,24 +182,30 @@ def __init__(self):
if 'librispeech' in workload_name:
self.tokenizer = _FakeTokenizer()
- def init_model_fn(self, rng, dropout_rate=None, aux_dropout_rate=None):
+ def init_model_fn(self, rng):
# pylint: disable=line-too-long
- if not (FLAGS.identical and
- os.path.exists(f'tests/modeldiffs/{workload_name}/compare.py')):
- return super().init_model_fn(
- rng, dropout_rate=dropout_rate, aux_dropout_rate=aux_dropout_rate)
+ if not (
+ FLAGS.identical
+ and os.path.exists(f'tests/modeldiffs/{workload_name}/compare.py')
+ ):
+ return super().init_model_fn(rng)
if framework == 'jax':
compare_module = importlib.import_module(
- f'tests.modeldiffs.{workload_name}.compare')
+ f'tests.modeldiffs.{workload_name}.compare'
+ )
jax_params, model_state, _ = diff_utils.torch2jax(
jax_workload=super(),
pytorch_workload=compare_module.PyTorchWorkload(**self.init_kwargs),
key_transform=compare_module.key_transform,
- sd_transform=compare_module.sd_transform)
- return (FrozenDict(**jax_utils.replicate(jax_params)),
- FrozenDict(**jax_utils.replicate(model_state))
- if model_state is not None else model_state)
- return super().init_model_fn([0], dropout_rate=0.0, aux_dropout_rate=0.0)
+ sd_transform=compare_module.sd_transform,
+ )
+ return (
+ FrozenDict(**jax_utils.replicate(jax_params)),
+ FrozenDict(**jax_utils.replicate(model_state))
+ if model_state is not None
+ else model_state,
+ )
+ return super().init_model_fn([0])
@property
def num_eval_train_examples(self):
@@ -236,73 +240,74 @@ def _build_input_queue(self, *args, **kwargs):
else:
data_shape = (3, 32, 32)
fake_batch = _make_fake_image_batch(
- batch_shape, data_shape=data_shape, num_classes=10)
+ batch_shape, data_shape=data_shape, num_classes=10
+ )
elif workload_name == 'criteo1tb' or workload_name == 'criteo1tb_test':
targets = np.ones(batch_shape)
targets[0] = 0
fake_batch = {
- 'inputs': np.ones((*batch_shape, 13 + 26)),
- 'targets': targets,
- 'weights': np.ones(batch_shape),
+ 'inputs': np.ones((*batch_shape, 13 + 26)),
+ 'targets': targets,
+ 'weights': np.ones(batch_shape),
}
elif workload_name in ['imagenet_resnet', 'imagenet_vit']:
data_shape = (224, 224, 3)
fake_batch = _make_fake_image_batch(
- batch_shape, data_shape=data_shape, num_classes=1000)
+ batch_shape, data_shape=data_shape, num_classes=1000
+ )
if framework == 'pytorch':
num_dims = len(fake_batch['inputs'].shape)
fake_batch['inputs'] = fake_batch['inputs'].transpose(
- (*range(num_dims - 3), num_dims - 1, num_dims - 3, num_dims - 2))
+ (*range(num_dims - 3), num_dims - 1, num_dims - 3, num_dims - 2)
+ )
elif 'librispeech' in workload_name:
rate = 16000
- l = None
- while l is None or l.shape[-1] < 320000:
+ audio_signal = None
+ while audio_signal is None or audio_signal.shape[-1] < 320000:
duration = 0.5
- freq = 2**(np.random.rand(*batch_shape, 1) * 13)
+ freq = 2 ** (np.random.rand(*batch_shape, 1) * 13)
wav = np.sin(2 * np.pi * freq * np.arange(rate * duration) / rate)
- if l is None:
- l = wav
+ if audio_signal is None:
+ audio_signal = wav
else:
- l = np.concatenate([l, wav], axis=-1)
- inputs = l
+ audio_signal = np.concatenate([audio_signal, wav], axis=-1)
+ inputs = audio_signal
targets = np.random.randint(low=1, high=1024, size=(*batch_shape, 256))
tgt_pad = np.arange(0, 256)[tuple([None] * len(batch_shape))]
tgt_lengths = np.random.randint(
- low=100, high=256, size=(*batch_shape, 1))
+ low=100, high=256, size=(*batch_shape, 1)
+ )
tgt_pad = 1 * (tgt_pad > tgt_lengths)
fake_batch = {
- 'inputs': (inputs, np.zeros_like(inputs)),
- 'targets': (targets, tgt_pad),
+ 'inputs': (inputs, np.zeros_like(inputs)),
+ 'targets': (targets, tgt_pad),
}
elif workload_name == 'mnist':
fake_batch = _make_fake_image_batch(
- batch_shape, data_shape=(28, 28, 1), num_classes=10)
+ batch_shape, data_shape=(28, 28, 1), num_classes=10
+ )
elif workload_name == 'ogbg':
tf.random.set_seed(5)
def _fake_iter():
while True:
fake_batch = {
- 'num_nodes':
- tf.ones((1,), dtype=tf.int64),
- 'edge_index':
- tf.ones((1, 2), dtype=tf.int64),
- 'node_feat':
- tf.random.normal((1, 9)),
- 'edge_feat':
- tf.random.normal((1, 3)),
- 'labels':
- tf.cast(
- tf.random.uniform((self._num_outputs,),
- minval=0,
- maxval=2,
- dtype=tf.int32),
- tf.float32),
+ 'num_nodes': tf.ones((1,), dtype=tf.int64),
+ 'edge_index': tf.ones((1, 2), dtype=tf.int64),
+ 'node_feat': tf.random.normal((1, 9)),
+ 'edge_feat': tf.random.normal((1, 3)),
+ 'labels': tf.cast(
+ tf.random.uniform(
+ (self._num_outputs,), minval=0, maxval=2, dtype=tf.int32
+ ),
+ tf.float32,
+ ),
}
yield fake_batch
fake_batch_iter = ogbg_input_pipeline._get_batch_iterator(
- _fake_iter(), global_batch_size)
+ _fake_iter(), global_batch_size
+ )
fake_batch = next(fake_batch_iter) # pylint: disable=stop-iteration-return
if framework == 'pytorch':
fake_batch['inputs'] = _graph_map(_pytorch_map, fake_batch['inputs'])
@@ -311,48 +316,49 @@ def _fake_iter():
elif workload_name == 'wmt':
max_len = 256
fake_batch = {
- 'inputs':
- np.random.randint(
- low=0, high=32000, size=(*batch_shape, max_len)),
- 'targets':
- np.random.randint(
- low=0, high=32000, size=(*batch_shape, max_len)),
- 'weights':
- np.random.randint(low=0, high=2, size=(*batch_shape, max_len)),
+ 'inputs': np.random.randint(
+ low=0, high=32000, size=(*batch_shape, max_len)
+ ),
+ 'targets': np.random.randint(
+ low=0, high=32000, size=(*batch_shape, max_len)
+ ),
+ 'weights': np.random.randint(
+ low=0, high=2, size=(*batch_shape, max_len)
+ ),
}
self._tokenizer = _FakeTokenizer()
elif workload_name == 'fastmri':
data_shape = (320, 320)
fake_batch = {
- 'inputs':
- _make_fake_image_batch(
- batch_shape, data_shape=data_shape, num_classes=1000)
- ['inputs'],
- 'targets':
- _make_fake_image_batch(
- batch_shape, data_shape=data_shape, num_classes=1000)
- ['inputs'],
- 'mean':
- np.zeros(batch_shape),
- 'std':
- np.ones(batch_shape),
- 'volume_max':
- np.zeros(batch_shape),
- 'weights':
- np.ones(batch_shape),
+ 'inputs': _make_fake_image_batch(
+ batch_shape, data_shape=data_shape, num_classes=1000
+ )['inputs'],
+ 'targets': _make_fake_image_batch(
+ batch_shape, data_shape=data_shape, num_classes=1000
+ )['inputs'],
+ 'mean': np.zeros(batch_shape),
+ 'std': np.ones(batch_shape),
+ 'volume_max': np.zeros(batch_shape),
+ 'weights': np.ones(batch_shape),
}
else:
raise ValueError(
- 'Workload {} does not have a fake batch defined, you '
- 'can add it or use --use_fake_input_queue=false.'.format(
- workload_name))
+ 'Workload {} does not have a fake batch defined, you '
+ 'can add it or use --use_fake_input_queue=false.'.format(
+ workload_name
+ )
+ )
if framework == 'pytorch':
def to_device(k, v):
dtype = (
- torch.long if (k == 'targets' and workload_name != 'fastmri') else
- torch.bool if k == 'weights' else torch.float)
+ torch.long
+ if (k == 'targets' and workload_name != 'fastmri')
+ else torch.bool
+ if k == 'weights'
+ else torch.float
+ )
if USE_PYTORCH_DDP:
v = v[RANK]
return torch.as_tensor(v, device=PYTORCH_DEVICE, dtype=dtype)
@@ -388,24 +394,28 @@ def eval_model(self, *args, **kwargs):
return _OneEvalBatchWorkload()
-def _test_submission(workload_name,
- framework,
- submission_path,
- search_space_path,
- data_dir,
- use_fake_input_queue,
- n_gpus):
+def _test_submission(
+ workload_name,
+ framework,
+ submission_path,
+ search_space_path,
+ data_dir,
+ use_fake_input_queue,
+ n_gpus,
+):
logging.info(f'========= Testing {workload_name} in {framework}.')
FLAGS.framework = framework
workload_metadata = copy.deepcopy(submission_runner.WORKLOADS[workload_name])
workload_metadata['workload_path'] = os.path.join(
- submission_runner.BASE_WORKLOADS_DIR,
- workload_metadata['workload_path'] + '_' + framework,
- 'workload.py')
+ submission_runner.BASE_WORKLOADS_DIR,
+ workload_metadata['workload_path'] + '_' + framework,
+ 'workload.py',
+ )
workload_class = workloads.import_workload(
- workload_path=workload_metadata['workload_path'],
- workload_class_name=workload_metadata['workload_class_name'],
- return_class=True)
+ workload_path=workload_metadata['workload_path'],
+ workload_class_name=workload_metadata['workload_class_name'],
+ return_class=True,
+ )
print(f'Workload class for {workload_name} is {workload_class}')
submission_module_path = workloads.convert_filepath_to_module(submission_path)
@@ -422,30 +432,32 @@ def _test_submission(workload_name,
global_batch_size = FLAGS.global_batch_size
if FLAGS.global_batch_size < 0:
raise ValueError('Must set --global_batch_size.')
- workload = _make_one_batch_workload(workload_class,
- workload_name,
- framework,
- global_batch_size,
- use_fake_input_queue,
- n_gpus)
+ workload = _make_one_batch_workload(
+ workload_class,
+ workload_name,
+ framework,
+ global_batch_size,
+ use_fake_input_queue,
+ n_gpus,
+ )
# Get a sample hyperparameter setting.
hyperparameters = {}
if search_space_path != 'None':
with open(search_space_path, 'r', encoding='UTF-8') as search_space_file:
hyperparameters = halton.generate_search(
- json.load(search_space_file), num_trials=1)[0]
+ json.load(search_space_file), num_trials=1
+ )[0]
rng = prng.PRNGKey(0)
data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4)
input_queue = workload._build_input_queue(
- data_rng, 'train', data_dir=data_dir, global_batch_size=global_batch_size)
+ data_rng, 'train', data_dir=data_dir, global_batch_size=global_batch_size
+ )
model_params, model_state = workload.init_model_fn(model_init_rng)
- optimizer_state = init_optimizer_state(workload,
- model_params,
- model_state,
- hyperparameters,
- opt_init_rng)
+ optimizer_state = init_optimizer_state(
+ workload, model_params, model_state, hyperparameters, opt_init_rng
+ )
if USE_PYTORCH_DDP:
torch.cuda.empty_cache()
@@ -453,44 +465,49 @@ def _test_submission(workload_name,
for global_step in range(FLAGS.num_train_steps):
step_rng = prng.fold_in(rng, global_step)
data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3)
- batch = data_selection(workload,
- input_queue,
- optimizer_state,
- model_params,
- model_state,
- hyperparameters,
- global_step,
- data_select_rng)
+ batch = data_selection(
+ workload,
+ input_queue,
+ optimizer_state,
+ model_params,
+ model_state,
+ hyperparameters,
+ global_step,
+ data_select_rng,
+ )
optimizer_state, model_params, model_state = update_params(
- workload=workload,
- current_param_container=model_params,
- current_params_types=workload.model_params_types,
- model_state=model_state,
- hyperparameters=hyperparameters,
- batch=batch,
- loss_type=workload.loss_type,
- optimizer_state=optimizer_state,
- train_state={},
- eval_results=[],
- global_step=global_step,
- rng=update_rng)
+ workload=workload,
+ current_param_container=model_params,
+ current_params_types=workload.model_params_types,
+ model_state=model_state,
+ hyperparameters=hyperparameters,
+ batch=batch,
+ loss_type=workload.loss_type,
+ optimizer_state=optimizer_state,
+ train_state={},
+ eval_results=[],
+ global_step=global_step,
+ rng=update_rng,
+ )
eval_result = workload.eval_model(
- global_batch_size,
- model_params,
- model_state,
- eval_rng,
- data_dir,
- imagenet_v2_data_dir=None,
- global_step=global_step)
- _ = workload.eval_model(
global_batch_size,
model_params,
model_state,
eval_rng,
data_dir,
imagenet_v2_data_dir=None,
- global_step=global_step)
+ global_step=global_step,
+ )
+ _ = workload.eval_model(
+ global_batch_size,
+ model_params,
+ model_state,
+ eval_rng,
+ data_dir,
+ imagenet_v2_data_dir=None,
+ global_step=global_step,
+ )
return eval_result
@@ -500,12 +517,15 @@ def _make_paths(repo_location, framework, workload_name):
else:
dataset_name = workload_name
workload_dir = (
- f'{repo_location}/reference_algorithms/target_setting_algorithms/'
- f'{workload_name}')
+ f'{repo_location}/reference_algorithms/target_setting_algorithms/'
+ f'{workload_name}'
+ )
search_space_path = f'{workload_dir}/tuning_search_space.json'
- submission_path = (f'reference_algorithms/target_setting_algorithms/'
- f'{workload_name}/{dataset_name}_{framework}/'
- 'submission.py')
+ submission_path = (
+ f'reference_algorithms/target_setting_algorithms/'
+ f'{workload_name}/{dataset_name}_{framework}/'
+ 'submission.py'
+ )
full_submission_path = f'{repo_location}/{submission_path}'
if not os.path.exists(full_submission_path):
return None, None
@@ -535,7 +555,8 @@ def test_submission(self):
if FLAGS.tuning_search_space:
raise ValueError('Cannot set --tuning_search_space and --all.')
references_dir = (
- f'{repo_location}/reference_algorithms/target_setting_algorithms')
+ f'{repo_location}/reference_algorithms/target_setting_algorithms'
+ )
for workload_name in os.listdir(references_dir):
for framework in ['jax', 'pytorch']:
if framework == 'pytorch':
@@ -543,17 +564,19 @@ def test_submission(self):
# First jax operation has to be called after pytorch_init.
n_gpus = max(N_GPUS, jax.local_device_count())
search_space_path, submission_path = _make_paths(
- repo_location, framework, workload_name)
+ repo_location, framework, workload_name
+ )
if search_space_path is None:
continue
eval_result = _test_submission(
- workload_name,
- framework,
- submission_path,
- search_space_path,
- data_dir=FLAGS.data_dir,
- use_fake_input_queue=FLAGS.use_fake_input_queue,
- n_gpus=n_gpus)
+ workload_name,
+ framework,
+ submission_path,
+ search_space_path,
+ data_dir=FLAGS.data_dir,
+ use_fake_input_queue=FLAGS.use_fake_input_queue,
+ n_gpus=n_gpus,
+ )
self._assert_eval_result(workload_name, eval_result)
else:
framework = FLAGS.framework
@@ -567,15 +590,17 @@ def test_submission(self):
submission_path = FLAGS.submission_path
else:
search_space_path, submission_path = _make_paths(
- repo_location, framework, workload_name)
+ repo_location, framework, workload_name
+ )
eval_result = _test_submission(
- workload_name,
- framework,
- submission_path,
- search_space_path,
- data_dir=FLAGS.data_dir,
- use_fake_input_queue=FLAGS.use_fake_input_queue,
- n_gpus=n_gpus)
+ workload_name,
+ framework,
+ submission_path,
+ search_space_path,
+ data_dir=FLAGS.data_dir,
+ use_fake_input_queue=FLAGS.use_fake_input_queue,
+ n_gpus=n_gpus,
+ )
self._assert_eval_result(workload_name, eval_result)
if USE_PYTORCH_DDP:
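The `to_device` helper in the fake input queue above picks a tensor dtype per batch key: integer targets (except for FastMRI, whose targets are images), boolean weights, float for everything else. A standalone CPU sketch of that selection, using a hypothetical WMT-style batch:

```python
# Sketch of the per-key dtype selection in to_device (CPU only, no DDP).
import numpy as np
import torch

workload_name = 'wmt'  # assumed; any non-'fastmri' name behaves the same here

def to_device(k, v):
  dtype = (
    torch.long
    if (k == 'targets' and workload_name != 'fastmri')
    else torch.bool
    if k == 'weights'
    else torch.float
  )
  return torch.as_tensor(v, dtype=dtype)

batch = {
  'inputs': np.zeros((2, 8)),
  'targets': np.zeros((2, 8)),
  'weights': np.ones((2, 8)),
}
batch = {k: to_device(k, v) for k, v in batch.items()}
assert batch['targets'].dtype == torch.long
assert batch['weights'].dtype == torch.bool
```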
diff --git a/tests/submission_runner_test.py b/tests/submission_runner_test.py
index ff724b201..c6c993b7b 100644
--- a/tests/submission_runner_test.py
+++ b/tests/submission_runner_test.py
@@ -4,17 +4,16 @@
dataset to be available. For testing the workload and reference submission code
for all workloads, see reference_algorithm_tests.py.
"""
+
import copy
import os
import sys
-from absl import flags
-from absl import logging
-from absl.testing import absltest
-from absl.testing import parameterized
+from absl import flags, logging
+from absl.testing import absltest, parameterized
-from algoperf.profiler import PassThroughProfiler
import submission_runner
+from algoperf.profiler import PassThroughProfiler
FLAGS = flags.FLAGS
# Needed to avoid UnparsedFlagAccessError
@@ -28,47 +27,46 @@ class SubmissionRunnerTest(parameterized.TestCase):
"""Tests for reference submissions."""
@parameterized.named_parameters(
- dict(
- testcase_name='mnist_jax',
- workload='mnist',
- framework='jax',
- submission_path=(f'{_MNIST_DEV_ALGO_DIR}/mnist_jax/submission.py'),
- tuning_search_space=(
- f'{_MNIST_DEV_ALGO_DIR}/tuning_search_space.json')),
- dict(
- testcase_name='mnist_pytorch',
- workload='mnist',
- framework='pytorch',
- submission_path=(
- f'{_MNIST_DEV_ALGO_DIR}/mnist_pytorch/submission.py'),
- tuning_search_space=(
- f'{_MNIST_DEV_ALGO_DIR}/tuning_search_space.json')),
+ dict(
+ testcase_name='mnist_jax',
+ workload='mnist',
+ framework='jax',
+ submission_path=(f'{_MNIST_DEV_ALGO_DIR}/mnist_jax/submission.py'),
+ tuning_search_space=(f'{_MNIST_DEV_ALGO_DIR}/tuning_search_space.json'),
+ ),
+ dict(
+ testcase_name='mnist_pytorch',
+ workload='mnist',
+ framework='pytorch',
+ submission_path=(f'{_MNIST_DEV_ALGO_DIR}/mnist_pytorch/submission.py'),
+ tuning_search_space=(f'{_MNIST_DEV_ALGO_DIR}/tuning_search_space.json'),
+ ),
)
- def test_submission(self,
- workload,
- framework,
- submission_path,
- tuning_search_space):
+ def test_submission(
+ self, workload, framework, submission_path, tuning_search_space
+ ):
FLAGS.framework = framework
workload_metadata = copy.deepcopy(submission_runner.WORKLOADS[workload])
workload_metadata['workload_path'] = os.path.join(
- submission_runner.BASE_WORKLOADS_DIR,
- workload_metadata['workload_path'] + '_' + framework,
- 'workload.py')
+ submission_runner.BASE_WORKLOADS_DIR,
+ workload_metadata['workload_path'] + '_' + framework,
+ 'workload.py',
+ )
workload_obj = submission_runner.import_workload(
- workload_path=workload_metadata['workload_path'],
- workload_class_name=workload_metadata['workload_class_name'],
- workload_init_kwargs={})
+ workload_path=workload_metadata['workload_path'],
+ workload_class_name=workload_metadata['workload_class_name'],
+ workload_init_kwargs={},
+ )
score = submission_runner.score_submission_on_workload(
- workload_obj,
- workload,
- submission_path,
- data_dir='~/tensorflow_datasets', # The default in TFDS.
- tuning_ruleset='external',
- tuning_search_space=tuning_search_space,
- num_tuning_trials=1,
- profiler=PassThroughProfiler(),
- max_global_steps=500,
+ workload_obj,
+ workload,
+ submission_path,
+ data_dir='~/tensorflow_datasets', # The default in TFDS.
+ tuning_ruleset='external',
+ tuning_search_space=tuning_search_space,
+ num_tuning_trials=1,
+ profiler=PassThroughProfiler(),
+ max_global_steps=500,
)
logging.info(score)
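The MNIST runner test above is a standard absl parameterized test: each dict in `named_parameters` becomes one named test case whose fields are passed as keyword arguments. A minimal, self-contained sketch of the same pattern, with a toy test body in place of `score_submission_on_workload`:

```python
# Sketch of the absl parameterized pattern; the class and assertions are toy
# examples, not part of the repository.
from absl.testing import absltest, parameterized


class FrameworkNameTest(parameterized.TestCase):
  @parameterized.named_parameters(
    dict(testcase_name='mnist_jax', workload='mnist', framework='jax'),
    dict(testcase_name='mnist_pytorch', workload='mnist', framework='pytorch'),
  )
  def test_workload_name(self, workload, framework):
    self.assertIn(framework, ('jax', 'pytorch'))
    self.assertEqual(workload, 'mnist')


if __name__ == '__main__':
  absltest.main()
```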
diff --git a/tests/test_baselines.py b/tests/test_baselines.py
index b2be8aa11..c5097a567 100644
--- a/tests/test_baselines.py
+++ b/tests/test_baselines.py
@@ -1,20 +1,19 @@
"""Tests for submission.py for baselines.
-This is an end-to-end test for all baselines on MNIST in PyTorch and Jax that
+This is an end-to-end test for all baselines on MNIST in PyTorch and Jax that
requires the dataset to be available.
"""
+
import copy
import os
import sys
-from absl import flags
-from absl import logging
-from absl.testing import absltest
-from absl.testing import parameterized
+from absl import flags, logging
+from absl.testing import absltest, parameterized
+import submission_runner
from algoperf.profiler import PassThroughProfiler
from algoperf.workloads import workloads
-import submission_runner
FLAGS = flags.FLAGS
# Needed to avoid UnparsedFlagAccessError
@@ -24,41 +23,42 @@
MAX_GLOBAL_STEPS = 5
baselines = {
- 'jax': [
- 'adafactor',
- 'adamw',
- 'lamb',
- 'momentum',
- 'nadamw',
- 'nesterov',
- 'sam',
- 'shampoo',
- ],
- 'pytorch': [
- 'adamw',
- 'momentum',
- 'nadamw',
- 'nesterov',
- ],
+ 'jax': [
+ 'adafactor',
+ 'adamw',
+ 'lamb',
+ 'momentum',
+ 'nadamw',
+ 'nesterov',
+ 'sam',
+ 'shampoo',
+ ],
+ 'pytorch': [
+ 'adamw',
+ 'momentum',
+ 'nadamw',
+ 'nesterov',
+ ],
}
frameworks = [
- 'pytorch',
- 'jax',
+ 'pytorch',
+ 'jax',
]
-baseline_path = "reference_algorithms/paper_baselines"
+baseline_path = 'reference_algorithms/paper_baselines'
named_parameters = []
for f in frameworks:
for b in baselines[f]:
named_parameters.append(
- dict(
- testcase_name=f'{b}_{f}',
- workload='mnist',
- framework=f'{f}',
- submission_path=f'{baseline_path}/{b}/{f}/submission.py',
- tuning_search_space=f'{baseline_path}/{b}/tuning_search_space.json')
+ dict(
+ testcase_name=f'{b}_{f}',
+ workload='mnist',
+ framework=f'{f}',
+ submission_path=f'{baseline_path}/{b}/{f}/submission.py',
+ tuning_search_space=f'{baseline_path}/{b}/tuning_search_space.json',
+ )
)
@@ -66,31 +66,31 @@ class BaselineTest(parameterized.TestCase):
"""Tests for reference submissions."""
@parameterized.named_parameters(*named_parameters)
- def test_baseline_submission(self,
- workload,
- framework,
- submission_path,
- tuning_search_space):
+ def test_baseline_submission(
+ self, workload, framework, submission_path, tuning_search_space
+ ):
FLAGS.framework = framework
workload_metadata = copy.deepcopy(workloads.WORKLOADS[workload])
workload_metadata['workload_path'] = os.path.join(
- workloads.BASE_WORKLOADS_DIR,
- workload_metadata['workload_path'] + '_' + framework,
- 'workload.py')
+ workloads.BASE_WORKLOADS_DIR,
+ workload_metadata['workload_path'] + '_' + framework,
+ 'workload.py',
+ )
workload_obj = workloads.import_workload(
- workload_path=workload_metadata['workload_path'],
- workload_class_name=workload_metadata['workload_class_name'],
- workload_init_kwargs={})
+ workload_path=workload_metadata['workload_path'],
+ workload_class_name=workload_metadata['workload_class_name'],
+ workload_init_kwargs={},
+ )
score = submission_runner.score_submission_on_workload(
- workload_obj,
- workload,
- submission_path,
- data_dir='~/tensorflow_datasets', # The default in TFDS.
- tuning_ruleset='external',
- tuning_search_space=tuning_search_space,
- num_tuning_trials=1,
- profiler=PassThroughProfiler(),
- max_global_steps=MAX_GLOBAL_STEPS,
+ workload_obj,
+ workload,
+ submission_path,
+ data_dir='~/tensorflow_datasets', # The default in TFDS.
+ tuning_ruleset='external',
+ tuning_search_space=tuning_search_space,
+ num_tuning_trials=1,
+ profiler=PassThroughProfiler(),
+ max_global_steps=MAX_GLOBAL_STEPS,
)
logging.info(score)
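The baseline test matrix above is the cross-product of frameworks and per-framework optimizer names, each case pointing at its submission module and tuning search space. The explicit loops could equally be written as a comprehension; a trimmed sketch with only two baselines per framework:

```python
# Sketch of generating the (framework, baseline) test matrix as a comprehension.
baselines = {
  'jax': ['adamw', 'nadamw'],
  'pytorch': ['adamw', 'nadamw'],
}
baseline_path = 'reference_algorithms/paper_baselines'

named_parameters = [
  dict(
    testcase_name=f'{b}_{f}',
    workload='mnist',
    framework=f,
    submission_path=f'{baseline_path}/{b}/{f}/submission.py',
    tuning_search_space=f'{baseline_path}/{b}/tuning_search_space.json',
  )
  for f in baselines
  for b in baselines[f]
]
assert named_parameters[0]['testcase_name'] == 'adamw_jax'
```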
diff --git a/tests/test_jax_utils.py b/tests/test_jax_utils.py
new file mode 100644
index 000000000..28e506400
--- /dev/null
+++ b/tests/test_jax_utils.py
@@ -0,0 +1,215 @@
+"""
+Test algoperf.jax_utils.Dropout by comparing to flax.linen.Dropout
+Run it as: pytest
+"""
+
+from functools import partial
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from absl.testing import absltest, parameterized
+from jax.tree_util import tree_leaves, tree_map, tree_structure
+
+from algoperf.jax_utils import Dropout
+
+SEED = 1996
+DEFAULT_DROPOUT = 0.5
+
+
+def pytrees_are_equal(a, b, rtol=1e-5, atol=1e-8):
+ """
+ A custom function to check if two PyTrees are equal, handling floats with
+ a tolerance.
+ """
+ if tree_structure(a) != tree_structure(b):
+ return False
+
+ def leaf_comparator(x, y):
+ # Use allclose for floating-point JAX arrays
+ if isinstance(x, jnp.ndarray) and jnp.issubdtype(x.dtype, jnp.floating):
+ return jnp.allclose(x, y, rtol=rtol, atol=atol)
+ # Use standard equality for everything else
+ else:
+ return x == y
+
+ comparison_tree = tree_map(leaf_comparator, a, b)
+ all_equal = all(tree_leaves(comparison_tree))
+
+ return all_equal
+
+
+class LegacyDropoutModel(nn.Module):
+ dropout_rate: float = DEFAULT_DROPOUT
+
+ @nn.compact
+ def __call__(self, x, train):
+ return nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
+
+
+class DropoutModel(nn.Module):
+ @nn.compact
+ def __call__(self, x, train, dropout_rate=DEFAULT_DROPOUT):
+ return Dropout(rate=dropout_rate, deterministic=not train)(
+ x, rate=dropout_rate
+ )
+
+
+class DropoutTest(parameterized.TestCase):
+ @parameterized.named_parameters(
+ dict(testcase_name='Dropout, p=0.0, train', dropout_rate=0.0, mode='train'),
+ dict(testcase_name='Dropout, p=0.0, eval', dropout_rate=0.0, mode='eval'),
+ dict(testcase_name='Dropout, p=0.1, train', dropout_rate=0.1, mode='train'),
+ dict(testcase_name='Dropout, p=0.1, eval', dropout_rate=0.1, mode='eval'),
+ )
+ def test_forward(self, dropout_rate, mode):
+ """Compare forward pass of Dropout layer to flax.linen.Dropout in train and
+ eval mode.
+ """
+
+ # initialize models
+ rng, dropout_rng = jax.random.split(jax.random.key(SEED), 2)
+ fake_batch = jnp.ones((10,))
+ orig_model = LegacyDropoutModel(dropout_rate=dropout_rate)
+ cust_model = DropoutModel()
+
+ initial_variables_original = orig_model.init(
+ {'params': rng}, fake_batch, train=False
+ )
+ initial_variables_custom = cust_model.init(
+ {'params': rng}, fake_batch, train=False
+ )
+
+ assert pytrees_are_equal(
+ initial_variables_original, initial_variables_custom, rtol=1e-6
+ )
+
+ # forward pass
+ x = jnp.ones((10,))
+
+ train = mode == 'train'
+ y1 = orig_model.apply(
+ initial_variables_original, x, train=train, rngs={'dropout': dropout_rng}
+ )
+ y2 = cust_model.apply(
+ initial_variables_custom,
+ x,
+ train=train,
+ dropout_rate=dropout_rate,
+ rngs={'dropout': dropout_rng},
+ )
+
+ assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
+
+ @parameterized.named_parameters(
+ dict(testcase_name='Dropout, p=0.0, train', dropout_rate=0.0, mode='train'),
+ dict(testcase_name='Dropout, p=0.0, eval', dropout_rate=0.0, mode='eval'),
+ dict(testcase_name='Dropout, p=0.1, train', dropout_rate=0.1, mode='train'),
+ dict(testcase_name='Dropout, p=0.1, eval', dropout_rate=0.1, mode='eval'),
+ )
+ def test_dropout_update(self, dropout_rate, mode):
+ """Call forward pass of Dropout layer with two different dropout rates
+ and check that the output matches to flax.linen.Dropout in train and
+ eval mode.
+ """
+ # init model
+ rng, dropout_rng = jax.random.split(jax.random.key(SEED), 2)
+ fake_batch = jnp.ones((10,))
+ orig_model = LegacyDropoutModel(dropout_rate=dropout_rate)
+ cust_model = DropoutModel()
+
+ initial_variables_original = orig_model.init(
+ {'params': rng}, fake_batch, train=False
+ )
+
+ initial_variables_custom = cust_model.init(
+ {'params': rng}, fake_batch, train=False
+ )
+
+ assert pytrees_are_equal(
+ initial_variables_original, initial_variables_custom, rtol=1e-6
+ )
+
+ # forward pass
+ x = jnp.ones((10,))
+
+ train = mode == 'train'
+ y1 = orig_model.apply(
+ initial_variables_original, x, train=train, rngs={'dropout': dropout_rng}
+ )
+
+ _ = cust_model.apply(
+ initial_variables_custom,
+ x,
+ train=train,
+ dropout_rate=0.9,
+ rngs={'dropout': dropout_rng},
+ )
+
+ y2 = cust_model.apply(
+ initial_variables_custom,
+ x,
+ train=train,
+ dropout_rate=dropout_rate,
+ rngs={'dropout': dropout_rng},
+ )
+ assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
+
+ @parameterized.named_parameters(
+ dict(testcase_name='Dropout, p=0.0, train', dropout_rate=0.0, mode='train'),
+ dict(testcase_name='Dropout, p=0.0, eval', dropout_rate=0.0, mode='eval'),
+ dict(testcase_name='Dropout, p=0.1, train', dropout_rate=0.1, mode='train'),
+ dict(testcase_name='Dropout, p=0.1, eval', dropout_rate=0.1, mode='eval'),
+ )
+ def test_jitted_updates(self, dropout_rate, mode):
+ """Compare jitted updates with dropout."""
+
+ # initialize models
+ rng, dropout_rng = jax.random.split(jax.random.key(SEED), 2)
+ fake_batch = jnp.ones((10,))
+ orig_model = LegacyDropoutModel(dropout_rate=dropout_rate)
+ cust_model = DropoutModel()
+
+ initial_variables_original = orig_model.init(
+ {'params': rng}, fake_batch, train=False
+ )
+ initial_variables_custom = cust_model.init(
+ {'params': rng}, fake_batch, train=False
+ )
+
+ assert pytrees_are_equal(
+ initial_variables_original, initial_variables_custom, rtol=1e-6
+ )
+
+ # forward pass
+ x = jnp.ones((10,))
+
+ train = mode == 'train'
+ jitted_original_apply = jax.jit(
+ partial(orig_model.apply), static_argnames=['train']
+ )
+ jitted_custom_apply = jax.jit(
+ partial(cust_model.apply), static_argnames=['train']
+ )
+
+ for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
+ y1 = jitted_original_apply(
+ initial_variables_original,
+ x,
+ train=train,
+ rngs={'dropout': dropout_rng},
+ )
+
+ for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
+ y2 = jitted_custom_apply(
+ initial_variables_custom,
+ x,
+ train=train,
+ dropout_rate=d,
+ rngs={'dropout': dropout_rng},
+ )
+ assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
+
+
+if __name__ == '__main__':
+ absltest.main()
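The new dropout tests rely on one property of `flax.linen.Dropout`: with `deterministic=False` and the same `'dropout'` RNG, the sampled mask is identical, so two implementations can be compared output-for-output. A quick sketch of that property in isolation (toy module, not the `algoperf` Dropout):

```python
# Sketch: the same 'dropout' RNG yields the same mask, which is what makes an
# output-for-output comparison between dropout implementations meaningful.
import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyDropout(nn.Module):
  rate: float = 0.5

  @nn.compact
  def __call__(self, x, train):
    return nn.Dropout(rate=self.rate, deterministic=not train)(x)


x = jnp.ones((10,))
rng, dropout_rng = jax.random.split(jax.random.key(1996))
model = TinyDropout(rate=0.1)
variables = model.init({'params': rng}, x, train=False)

y1 = model.apply(variables, x, train=True, rngs={'dropout': dropout_rng})
y2 = model.apply(variables, x, train=True, rngs={'dropout': dropout_rng})
assert jnp.allclose(y1, y2)  # same RNG -> same mask -> same output
```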
diff --git a/tests/test_num_params.py b/tests/test_num_params.py
index b0633025e..9361f4c72 100644
--- a/tests/test_num_params.py
+++ b/tests/test_num_params.py
@@ -5,48 +5,59 @@
import pytest
import torch
-from algoperf.workloads.criteo1tb.criteo1tb_jax.models import \
- DlrmSmall as JaxDlrmSmall
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import \
- DlrmSmall as PyTorchDlrmSmall
-from algoperf.workloads.imagenet_resnet.imagenet_jax.models import \
- ResNet18 as JaxResNet_c10
-from algoperf.workloads.imagenet_resnet.imagenet_jax.models import \
- ResNet50 as JaxResNet
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \
- resnet18 as PyTorchResNet_c10
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \
- resnet50 as PyTorchResNet
+from algoperf.workloads.criteo1tb.criteo1tb_jax.models import (
+ DlrmSmall as JaxDlrmSmall,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import (
+ DlrmSmall as PyTorchDlrmSmall,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_jax.models import (
+ ResNet18 as JaxResNet_c10,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_jax.models import (
+ ResNet50 as JaxResNet,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import (
+ resnet18 as PyTorchResNet_c10,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import (
+ resnet50 as PyTorchResNet,
+)
from algoperf.workloads.imagenet_vit.imagenet_jax.models import ViT as JaxViT
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import \
- ViT as PyTorchViT
-from algoperf.workloads.librispeech_conformer.librispeech_jax.models import \
- Conformer as JaxConformer
-from algoperf.workloads.librispeech_conformer.librispeech_jax.models import \
- ConformerConfig as JaxConformerConfig
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
- ConformerConfig as PytorchConformerConfig
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
- ConformerEncoderDecoder as PytorchConformer
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import (
+ ViT as PyTorchViT,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_jax.models import (
+ Conformer as JaxConformer,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_jax.models import (
+ ConformerConfig as JaxConformerConfig,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import (
+ ConformerConfig as PytorchConformerConfig,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import (
+ ConformerEncoderDecoder as PytorchConformer,
+)
from algoperf.workloads.mnist.mnist_jax.workload import _Model as JaxMLP
-from algoperf.workloads.mnist.mnist_pytorch.workload import \
- _Model as PyTorchMLP
+from algoperf.workloads.mnist.mnist_pytorch.workload import _Model as PyTorchMLP
from algoperf.workloads.ogbg.ogbg_jax.models import GNN as JaxGNN
from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as PyTorchGNN
from algoperf.workloads.wmt.wmt_jax.models import Transformer as JaxTransformer
from algoperf.workloads.wmt.wmt_jax.models import TransformerConfig
-from algoperf.workloads.wmt.wmt_pytorch.models import \
- Transformer as PyTorchTransformer
+from algoperf.workloads.wmt.wmt_pytorch.models import (
+ Transformer as PyTorchTransformer,
+)
WORKLOADS = [
- 'mnist',
- 'cifar',
- 'criteo1tb',
- 'imagenet_resnet',
- 'imagenet_vit',
- 'wmt',
- 'ogbg',
- 'librispeech_conformer',
+ 'mnist',
+ 'cifar',
+ 'criteo1tb',
+ 'imagenet_resnet',
+ 'imagenet_vit',
+ 'wmt',
+ 'ogbg',
+ 'librispeech_conformer',
]
@@ -56,7 +67,8 @@ def test_matching_num_params(workload):
# Count parameters of both models.
num_jax_params = sum(x.size for x in jax.tree_util.tree_leaves(jax_model))
num_pytorch_params = sum(
- p.numel() for p in pytorch_model.parameters() if p.requires_grad)
+ p.numel() for p in pytorch_model.parameters() if p.requires_grad
+ )
assert num_jax_params == num_pytorch_params
@@ -72,8 +84,9 @@ def get_models(workload):
# Init Jax model.
input_shape = (1, 32, 32, 3)
model_init = jax.jit(JaxResNet_c10(num_classes=10, dtype=jnp.float32).init)
- jax_model = model_init(init_rngs, jnp.ones(input_shape,
- jnp.float32))["params"]
+ jax_model = model_init(init_rngs, jnp.ones(input_shape, jnp.float32))[
+ 'params'
+ ]
# Init PyTorch model.
pytorch_model = PyTorchResNet_c10(num_classes=10)
@@ -85,35 +98,38 @@ def get_models(workload):
vocab_size = 32 * 128 * 1024
input_shape = (1, 39)
model_init = JaxDlrmSmall(
- vocab_size=vocab_size,
- num_dense_features=13,
- mlp_bottom_dims=mlp_bottom_dims,
- mlp_top_dims=mlp_top_dims,
- embed_dim=embed_dim).init
- jax_model = model_init(init_rngs, jnp.ones(input_shape, jnp.float32),
- False)['params']
+ vocab_size=vocab_size,
+ num_dense_features=13,
+ mlp_bottom_dims=mlp_bottom_dims,
+ mlp_top_dims=mlp_top_dims,
+ embed_dim=embed_dim,
+ ).init
+ jax_model = model_init(
+ init_rngs, jnp.ones(input_shape, jnp.float32), False
+ )['params']
# Init PyTorch model.
pytorch_model = PyTorchDlrmSmall(
- vocab_size=vocab_size,
- num_dense_features=13,
- mlp_bottom_dims=mlp_bottom_dims,
- mlp_top_dims=mlp_top_dims,
- embed_dim=embed_dim)
+ vocab_size=vocab_size,
+ num_dense_features=13,
+ mlp_bottom_dims=mlp_bottom_dims,
+ mlp_top_dims=mlp_top_dims,
+ embed_dim=embed_dim,
+ )
elif workload == 'imagenet_resnet':
# Init Jax model.
input_shape = (1, 224, 224, 3)
- jax_model = JaxResNet(
- num_classes=1000,
- dtype=jnp.float32).init(init_rngs, jnp.ones(input_shape,
- jnp.float32))['params']
+ jax_model = JaxResNet(num_classes=1000, dtype=jnp.float32).init(
+ init_rngs, jnp.ones(input_shape, jnp.float32)
+ )['params']
# Init PyTorch model.
pytorch_model = PyTorchResNet()
elif workload == 'imagenet_vit':
# Init Jax model.
input_shape = (1, 224, 224, 3)
jax_model = JaxViT(num_classes=1000).init(
- init_rngs, jnp.ones(input_shape, jnp.float32))['params']
+ init_rngs, jnp.ones(input_shape, jnp.float32)
+ )['params']
# Init PyTorch model.
pytorch_model = PyTorchViT()
elif workload == 'librispeech_conformer':
@@ -123,8 +139,9 @@ def get_models(workload):
# Init Jax model
input_shape = [(320000,), (320000,)]
fake_input_batch = [jnp.zeros((2, *x), jnp.float32) for x in input_shape]
- jax_model = jax_model.init(
- init_rngs, train=False, *fake_input_batch)["params"]
+ jax_model = jax_model.init(init_rngs, train=False, *fake_input_batch)[
+ 'params'
+ ]
# Run model once to initialize lazy layers
wave = torch.randn(2, 320000)
@@ -136,23 +153,26 @@ def get_models(workload):
input_shape = (16, 256)
target_shape = (16, 256)
jax_model = JaxTransformer(TransformerConfig).init(
- init_rngs,
- jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32))['params']
+ init_rngs,
+ jnp.ones(input_shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32),
+ )['params']
# Init PyTorch model.
pytorch_model = PyTorchTransformer()
elif workload == 'ogbg':
# Init Jax model.
fake_batch = jraph.GraphsTuple(
- n_node=jnp.asarray([1]),
- n_edge=jnp.asarray([1]),
- nodes=jnp.ones((1, 9)),
- edges=jnp.ones((1, 3)),
- globals=jnp.zeros((1, 128)),
- senders=jnp.asarray([0]),
- receivers=jnp.asarray([0]))
+ n_node=jnp.asarray([1]),
+ n_edge=jnp.asarray([1]),
+ nodes=jnp.ones((1, 9)),
+ edges=jnp.ones((1, 3)),
+ globals=jnp.zeros((1, 128)),
+ senders=jnp.asarray([0]),
+ receivers=jnp.asarray([0]),
+ )
jax_model = JaxGNN(num_outputs=128).init(
- init_rngs, fake_batch, train=False)['params']
+ init_rngs, fake_batch, train=False
+ )['params']
# Init PyTorch model.
pytorch_model = PyTorchGNN(num_outputs=128)
else:
diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py
index df4c798d8..2243ce52e 100644
--- a/tests/test_param_shapes.py
+++ b/tests/test_param_shapes.py
@@ -5,42 +5,79 @@
import pytest
from flax.core import FrozenDict
-# isort: skip_file
-# pylint:disable=line-too-long
-from algoperf.workloads.cifar.cifar_jax.workload import CifarWorkload as JaxCifarWorkload
-from algoperf.workloads.cifar.cifar_pytorch.workload import CifarWorkload as PyTorchCifarWorkload
-from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload
-from algoperf.workloads.fastmri.fastmri_jax.workload import FastMRIWorkload as JaxFastMRIWorkload
-from algoperf.workloads.fastmri.fastmri_pytorch.workload import FastMRIWorkload as PyTorchFastMRIWorkload
-from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ImagenetResNetWorkload as JaxImagenetResNetWorkload
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ImagenetResNetWorkload as PyTorchImagenetResNetWorkload
-from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ImagenetVitWorkload as JaxImagenetViTWorkload
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ImagenetVitWorkload as PyTorchImagenetViTWorkload
-from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload
-from algoperf.workloads.mnist.mnist_jax.workload import MnistWorkload as JaxMnistWorkload
-from algoperf.workloads.mnist.mnist_pytorch.workload import MnistWorkload as PyTorchMnistWorkload
-from algoperf.workloads.ogbg.ogbg_jax.workload import OgbgWorkload as JaxOgbgWorkload
-from algoperf.workloads.ogbg.ogbg_pytorch.workload import OgbgWorkload as PyTorchOgbgWorkload
-from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWmtWorkload
-from algoperf.workloads.wmt.wmt_pytorch.workload import WmtWorkload as PyTorchWmtWorkload
-# pylint:enable=line-too-long
+from algoperf.workloads.cifar.cifar_jax.workload import (
+ CifarWorkload as JaxCifarWorkload,
+)
+from algoperf.workloads.cifar.cifar_pytorch.workload import (
+ CifarWorkload as PyTorchCifarWorkload,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import (
+ Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import (
+ Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload,
+)
+from algoperf.workloads.fastmri.fastmri_jax.workload import (
+ FastMRIWorkload as JaxFastMRIWorkload,
+)
+from algoperf.workloads.fastmri.fastmri_pytorch.workload import (
+ FastMRIWorkload as PyTorchFastMRIWorkload,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import (
+ ImagenetResNetWorkload as JaxImagenetResNetWorkload,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import (
+ ImagenetResNetWorkload as PyTorchImagenetResNetWorkload,
+)
+from algoperf.workloads.imagenet_vit.imagenet_jax.workload import (
+ ImagenetVitWorkload as JaxImagenetViTWorkload,
+)
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import (
+ ImagenetVitWorkload as PyTorchImagenetViTWorkload,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import (
+ LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import (
+ LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import (
+ LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import (
+ LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload,
+)
+from algoperf.workloads.mnist.mnist_jax.workload import (
+ MnistWorkload as JaxMnistWorkload,
+)
+from algoperf.workloads.mnist.mnist_pytorch.workload import (
+ MnistWorkload as PyTorchMnistWorkload,
+)
+from algoperf.workloads.ogbg.ogbg_jax.workload import (
+ OgbgWorkload as JaxOgbgWorkload,
+)
+from algoperf.workloads.ogbg.ogbg_pytorch.workload import (
+ OgbgWorkload as PyTorchOgbgWorkload,
+)
+from algoperf.workloads.wmt.wmt_jax.workload import (
+ WmtWorkload as JaxWmtWorkload,
+)
+from algoperf.workloads.wmt.wmt_pytorch.workload import (
+ WmtWorkload as PyTorchWmtWorkload,
+)
WORKLOADS = [
- 'cifar',
- 'criteo1tb',
- 'fastmri',
- 'imagenet_resnet',
- 'imagenet_vit',
- # TODO: make tests work for these.
- # 'librispeech_conformer',
- # 'librispeech_deepspeech',
- 'mnist',
- 'ogbg',
- 'wmt',
+ 'cifar',
+ 'criteo1tb',
+ 'fastmri',
+ 'imagenet_resnet',
+ 'imagenet_vit',
+ # TODO: make tests work for these.
+ # 'librispeech_conformer',
+ # 'librispeech_deepspeech',
+ 'mnist',
+ 'ogbg',
+ 'wmt',
]
@@ -56,9 +93,11 @@ def test_param_shapes(workload):
if isinstance(jax_workload_param_shapes, dict):
jax_workload_param_shapes = FrozenDict(jax_workload_param_shapes)
jax_param_shapes = jax.tree_util.tree_leaves(
- jax_workload_param_shapes.unfreeze())
+ jax_workload_param_shapes.unfreeze()
+ )
pytorch_param_shapes = jax.tree_util.tree_leaves(
- pytorch_workload.param_shapes)
+ pytorch_workload.param_shapes
+ )
if workload == 'wmt':
# The PyTorch transformer for WMT is implemented with fused linear layers
# for the projection of QKV inside of the MultiheadAttention module.
@@ -74,8 +113,9 @@ def test_param_shapes(workload):
# Check if total number of params deduced from shapes match.
num_jax_params = 0
num_pytorch_params = 0
- for jax_shape, pytorch_shape in zip_longest(jax_param_shapes,
- pytorch_param_shapes):
+ for jax_shape, pytorch_shape in zip_longest(
+ jax_param_shapes, pytorch_param_shapes
+ ):
if jax_shape is not None:
num_jax_params += np.prod(jax_shape.shape_tuple)
if pytorch_shape is not None:
diff --git a/tests/test_param_types.py b/tests/test_param_types.py
index d3722ae86..9f14f7dd8 100644
--- a/tests/test_param_types.py
+++ b/tests/test_param_types.py
@@ -1,44 +1,80 @@
import jax
import pytest
-
from absl import logging
-from algoperf import spec
-# isort: skip_file
-# pylint:disable=line-too-long
-from algoperf.workloads.cifar.cifar_jax.workload import CifarWorkload as JaxCifarWorkload
-from algoperf.workloads.cifar.cifar_pytorch.workload import CifarWorkload as PyTorchCifarWorkload
-from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload
-from algoperf.workloads.fastmri.fastmri_jax.workload import FastMRIWorkload as JaxFastMRIWorkload
-from algoperf.workloads.fastmri.fastmri_pytorch.workload import FastMRIWorkload as PyTorchFastMRIWorkload
-from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ImagenetResNetWorkload as JaxImagenetResNetWorkload
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ImagenetResNetWorkload as PyTorchImagenetResNetWorkload
-from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ImagenetVitWorkload as JaxImagenetViTWorkload
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ImagenetVitWorkload as PyTorchImagenetViTWorkload
-from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload
-from algoperf.workloads.mnist.mnist_jax.workload import MnistWorkload as JaxMnistWorkload
-from algoperf.workloads.mnist.mnist_pytorch.workload import MnistWorkload as PyTorchMnistWorkload
-from algoperf.workloads.ogbg.ogbg_jax.workload import OgbgWorkload as JaxOgbgWorkload
-from algoperf.workloads.ogbg.ogbg_pytorch.workload import OgbgWorkload as PyTorchOgbgWorkload
-from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWmtWorkload
-from algoperf.workloads.wmt.wmt_pytorch.workload import WmtWorkload as PyTorchWmtWorkload
-# pylint:enable=line-too-long
+from algoperf import spec
+from algoperf.workloads.cifar.cifar_jax.workload import (
+ CifarWorkload as JaxCifarWorkload,
+)
+from algoperf.workloads.cifar.cifar_pytorch.workload import (
+ CifarWorkload as PyTorchCifarWorkload,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import (
+ Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import (
+ Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload,
+)
+from algoperf.workloads.fastmri.fastmri_jax.workload import (
+ FastMRIWorkload as JaxFastMRIWorkload,
+)
+from algoperf.workloads.fastmri.fastmri_pytorch.workload import (
+ FastMRIWorkload as PyTorchFastMRIWorkload,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import (
+ ImagenetResNetWorkload as JaxImagenetResNetWorkload,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import (
+ ImagenetResNetWorkload as PyTorchImagenetResNetWorkload,
+)
+from algoperf.workloads.imagenet_vit.imagenet_jax.workload import (
+ ImagenetVitWorkload as JaxImagenetViTWorkload,
+)
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import (
+ ImagenetVitWorkload as PyTorchImagenetViTWorkload,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import (
+ LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import (
+ LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import (
+ LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import (
+ LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload,
+)
+from algoperf.workloads.mnist.mnist_jax.workload import (
+ MnistWorkload as JaxMnistWorkload,
+)
+from algoperf.workloads.mnist.mnist_pytorch.workload import (
+ MnistWorkload as PyTorchMnistWorkload,
+)
+from algoperf.workloads.ogbg.ogbg_jax.workload import (
+ OgbgWorkload as JaxOgbgWorkload,
+)
+from algoperf.workloads.ogbg.ogbg_pytorch.workload import (
+ OgbgWorkload as PyTorchOgbgWorkload,
+)
+from algoperf.workloads.wmt.wmt_jax.workload import (
+ WmtWorkload as JaxWmtWorkload,
+)
+from algoperf.workloads.wmt.wmt_pytorch.workload import (
+ WmtWorkload as PyTorchWmtWorkload,
+)
WORKLOADS = [
- 'cifar',
- 'criteo1tb',
- 'fastmri',
- 'imagenet_resnet',
- 'imagenet_vit',
- 'librispeech_conformer',
- 'librispeech_deepspeech',
- 'mnist',
- 'ogbg',
- 'wmt',
+ 'cifar',
+ 'criteo1tb',
+ 'fastmri',
+ 'imagenet_resnet',
+ 'imagenet_vit',
+ 'librispeech_conformer',
+ 'librispeech_deepspeech',
+ 'mnist',
+ 'ogbg',
+ 'wmt',
]
@@ -66,40 +102,32 @@ def _check_attention_qkv_match(jax_param_types_dict, pytorch_param_types_dict):
# Sometimes one framework will implement QKV as a single parameter, so we need
# to make sure there are the same number of QKV params as Q, K, V.
num_qkv = {
- 'jax':
- jax_param_types_dict.get(spec.ParameterType.ATTENTION_QKV, 0),
- 'pytorch':
- pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_QKV, 0),
+ 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_QKV, 0),
+ 'pytorch': pytorch_param_types_dict.get(
+ spec.ParameterType.ATTENTION_QKV, 0
+ ),
}
num_kv = {
- 'jax':
- jax_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0),
- 'pytorch':
- pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0),
+ 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0),
+ 'pytorch': pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0),
}
num_q = {
- 'jax':
- jax_param_types_dict.get(spec.ParameterType.ATTENTION_Q, 0),
- 'pytorch':
- pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_Q, 0),
+ 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_Q, 0),
+ 'pytorch': pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_Q, 0),
}
num_k = {
- 'jax':
- jax_param_types_dict.get(spec.ParameterType.ATTENTION_K, 0),
- 'pytorch':
- pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_K, 0),
+ 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_K, 0),
+ 'pytorch': pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_K, 0),
}
num_v = {
- 'jax':
- jax_param_types_dict.get(spec.ParameterType.ATTENTION_V, 0),
- 'pytorch':
- pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_V, 0),
+ 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_V, 0),
+ 'pytorch': pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_V, 0),
}
num_bias = {
- 'jax':
- jax_param_types_dict.get(spec.ParameterType.ATTENTION_BIAS, 0),
- 'pytorch':
- pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_BIAS, 0),
+ 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_BIAS, 0),
+ 'pytorch': pytorch_param_types_dict.get(
+ spec.ParameterType.ATTENTION_BIAS, 0
+ ),
}
qkv_match = num_qkv['jax'] == num_qkv['pytorch']
kv_match = num_kv['jax'] == num_kv['pytorch']
@@ -108,24 +136,33 @@ def _check_attention_qkv_match(jax_param_types_dict, pytorch_param_types_dict):
v_match = num_v['jax'] == num_v['pytorch']
bias_match = num_bias['jax'] == num_bias['pytorch']
qkv_match = (
- qkv_match and kv_match and q_match and k_match and v_match and bias_match)
+ qkv_match and kv_match and q_match and k_match and v_match and bias_match
+ )
# We subtract 2 * num_qkv from the number of biases because there are 2
# missing for each of q, k, v.
- jax_qkv_match = (
- num_q['pytorch'] == num_k['pytorch'] == num_v['pytorch'] == num_qkv['jax']
- and (num_qkv['jax'] != 0 and
- (num_bias['pytorch'] - 2 * num_qkv['jax']) == num_bias['jax']))
- pytorch_qkv_match = (
- num_q['jax'] == num_k['jax'] == num_v['jax'] == num_qkv['pytorch'] and
- (num_qkv['pytorch'] != 0 and
- (num_bias['jax'] - 2 * num_qkv['pytorch']) == num_bias['pytorch']))
+ jax_qkv_match = num_q['pytorch'] == num_k['pytorch'] == num_v[
+ 'pytorch'
+ ] == num_qkv['jax'] and (
+ num_qkv['jax'] != 0
+ and (num_bias['pytorch'] - 2 * num_qkv['jax']) == num_bias['jax']
+ )
+ pytorch_qkv_match = num_q['jax'] == num_k['jax'] == num_v['jax'] == num_qkv[
+ 'pytorch'
+ ] and (
+ num_qkv['pytorch'] != 0
+ and (num_bias['jax'] - 2 * num_qkv['pytorch']) == num_bias['pytorch']
+ )
pytorch_kv_match = (
- num_q['jax'] == num_k['jax'] == num_v['jax'] ==
- num_qkv['pytorch'] + num_kv['pytorch'] and
- num_q['pytorch'] == num_kv['pytorch'])
+ num_q['jax']
+ == num_k['jax']
+ == num_v['jax']
+ == num_qkv['pytorch'] + num_kv['pytorch']
+ and num_q['pytorch'] == num_kv['pytorch']
+ )
qkv_match = (
- qkv_match or jax_qkv_match or pytorch_qkv_match or pytorch_kv_match)
+ qkv_match or jax_qkv_match or pytorch_qkv_match or pytorch_kv_match
+ )
return qkv_match
@@ -137,7 +174,8 @@ def test_param_types(workload_name):
# Compare number of parameter tensors of both models.
jax_param_types = jax.tree_util.tree_leaves(jax_workload.model_params_types)
pytorch_param_types = jax.tree_util.tree_leaves(
- pytorch_workload.model_params_types)
+ pytorch_workload.model_params_types
+ )
jax_param_types_dict = count_param_types(jax_param_types)
pytorch_param_types_dict = count_param_types(pytorch_param_types)
@@ -161,30 +199,33 @@ def test_param_types(workload_name):
# Check if total number of each type match.
attention_keys = {
- spec.ParameterType.ATTENTION_QKV,
- spec.ParameterType.ATTENTION_KV,
- spec.ParameterType.ATTENTION_Q,
- spec.ParameterType.ATTENTION_K,
- spec.ParameterType.ATTENTION_V,
- spec.ParameterType.ATTENTION_BIAS,
+ spec.ParameterType.ATTENTION_QKV,
+ spec.ParameterType.ATTENTION_KV,
+ spec.ParameterType.ATTENTION_Q,
+ spec.ParameterType.ATTENTION_K,
+ spec.ParameterType.ATTENTION_V,
+ spec.ParameterType.ATTENTION_BIAS,
}
non_attention_keys = set(jax_param_types_dict.keys()).union(
- set(pytorch_param_types_dict.keys()))
+ set(pytorch_param_types_dict.keys())
+ )
non_attention_keys -= attention_keys
mismatches = ''
- mismatches += _count_mismatches(jax_param_types_dict,
- pytorch_param_types_dict,
- non_attention_keys)
- qkv_match = _check_attention_qkv_match(jax_param_types_dict,
- pytorch_param_types_dict)
+ mismatches += _count_mismatches(
+ jax_param_types_dict, pytorch_param_types_dict, non_attention_keys
+ )
+ qkv_match = _check_attention_qkv_match(
+ jax_param_types_dict, pytorch_param_types_dict
+ )
if not qkv_match:
- mismatches += _count_mismatches(jax_param_types_dict,
- pytorch_param_types_dict,
- attention_keys)
+ mismatches += _count_mismatches(
+ jax_param_types_dict, pytorch_param_types_dict, attention_keys
+ )
if mismatches:
raise ValueError(
- f'On workload {workload_name}, count mismatch: {mismatches}')
+ f'On workload {workload_name}, count mismatch: {mismatches}'
+ )
def get_workload(workload_name):
diff --git a/tests/test_ssim.py b/tests/test_ssim.py
index 920556964..dcb3f25e0 100644
--- a/tests/test_ssim.py
+++ b/tests/test_ssim.py
@@ -3,20 +3,20 @@
import os
from typing import Tuple
-from absl.testing import absltest
-from absl.testing import parameterized
import jax.numpy as jnp
import numpy as np
import torch
+from absl.testing import absltest, parameterized
from algoperf.pytorch_utils import pytorch_setup
-from algoperf.workloads.fastmri.fastmri_jax.ssim import \
- _uniform_filter as _jax_uniform_filter
+from algoperf.workloads.fastmri.fastmri_jax.ssim import (
+ _uniform_filter as _jax_uniform_filter,
+)
from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim as jax_ssim
-from algoperf.workloads.fastmri.fastmri_pytorch.ssim import \
- _uniform_filter as _pytorch_uniform_filter
-from algoperf.workloads.fastmri.fastmri_pytorch.ssim import \
- ssim as pytorch_ssim
+from algoperf.workloads.fastmri.fastmri_pytorch.ssim import (
+ _uniform_filter as _pytorch_uniform_filter,
+)
+from algoperf.workloads.fastmri.fastmri_pytorch.ssim import ssim as pytorch_ssim
# Make sure no GPU memory is preallocated to Jax.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
@@ -31,7 +31,7 @@ def _create_fake_im(height: int, width: int) -> Tuple[jnp.array, torch.Tensor]:
def _create_fake_batch(
- batch_size: int, height: int, width: int
+ batch_size: int, height: int, width: int
) -> Tuple[Tuple[jnp.array, jnp.array], Tuple[torch.Tensor, torch.Tensor]]:
logits = np.random.randn(batch_size, height, width)
targets = np.random.randn(batch_size, height, width)
@@ -47,9 +47,9 @@ class SSIMTest(parameterized.TestCase):
and PyTorch."""
@parameterized.named_parameters(
- dict(testcase_name='fastmri_im', height=320, width=320),
- dict(testcase_name='uneven_even_im', height=31, width=16),
- dict(testcase_name='even_uneven_im', height=42, width=53),
+ dict(testcase_name='fastmri_im', height=320, width=320),
+ dict(testcase_name='uneven_even_im', height=31, width=16),
+ dict(testcase_name='even_uneven_im', height=42, width=53),
)
def test_uniform_filter(self, height: int, width: int) -> None:
jax_im, pytorch_im = _create_fake_im(height, width)
@@ -58,12 +58,9 @@ def test_uniform_filter(self, height: int, width: int) -> None:
assert np.allclose(jax_result, torch_result, atol=1e-6)
@parameterized.named_parameters(
- dict(
- testcase_name='fastmri_batch', batch_size=256, height=320, width=320),
- dict(
- testcase_name='uneven_even_batch', batch_size=8, height=31, width=16),
- dict(
- testcase_name='even_uneven_batch', batch_size=8, height=42, width=53),
+ dict(testcase_name='fastmri_batch', batch_size=256, height=320, width=320),
+ dict(testcase_name='uneven_even_batch', batch_size=8, height=31, width=16),
+ dict(testcase_name='even_uneven_batch', batch_size=8, height=42, width=53),
)
def test_ssim(self, batch_size: int, height: int, width: int) -> None:
jax_inputs, pytorch_inputs = _create_fake_batch(batch_size, height, width)
@@ -71,9 +68,8 @@ def test_ssim(self, batch_size: int, height: int, width: int) -> None:
pytorch_ssim_result = pytorch_ssim(*pytorch_inputs)
self.assertEqual(jax_ssim_result.shape, pytorch_ssim_result.shape)
assert np.allclose(
- jax_ssim_result.sum().item(),
- pytorch_ssim_result.sum().item(),
- atol=1e-6)
+ jax_ssim_result.sum().item(), pytorch_ssim_result.sum().item(), atol=1e-6
+ )
if __name__ == '__main__':
diff --git a/tests/test_traindiffs.py b/tests/test_traindiffs.py
index cea589202..8acfc855a 100644
--- a/tests/test_traindiffs.py
+++ b/tests/test_traindiffs.py
@@ -3,28 +3,26 @@
Run it as:
python3 test_traindiffs.py
"""
+
import pickle
import subprocess
-from subprocess import DEVNULL
-from subprocess import run
-from subprocess import STDOUT
+from subprocess import DEVNULL, STDOUT, run
from absl import flags
-from absl.testing import absltest
-from absl.testing import parameterized
+from absl.testing import absltest, parameterized
from numpy import allclose
FLAGS = flags.FLAGS
WORKLOADS = [
- 'imagenet_resnet',
- 'imagenet_vit',
- 'wmt',
- 'librispeech_conformer',
- 'librispeech_deepspeech',
- 'fastmri',
- 'ogbg',
- 'criteo1tb'
+ 'imagenet_resnet',
+ 'imagenet_vit',
+ 'wmt',
+ 'librispeech_conformer',
+ 'librispeech_deepspeech',
+ 'fastmri',
+ 'ogbg',
+ 'criteo1tb',
]
GLOBAL_BATCH_SIZE = 16
NUM_TRAIN_STEPS = 10
@@ -35,7 +33,6 @@
class ModelDiffTest(parameterized.TestCase):
-
@parameterized.named_parameters(*named_parameters)
def test_workload(self, workload):
# pylint: disable=line-too-long, unnecessary-lambda-assignment
@@ -50,24 +47,26 @@ def test_workload(self, workload):
pytorch_logs_path = '/tmp/pyt_log.pkl'
try:
run(
- f'XLA_PYTHON_CLIENT_ALLOCATOR=platform python -m tests.reference_algorithm_tests --workload={workload} --framework=jax --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={jax_logs_path}'
- f' --submission_path=tests/modeldiffs/vanilla_sgd_jax.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}',
- shell=True,
- stdout=DEVNULL,
- stderr=STDOUT,
- check=True)
+ f'XLA_PYTHON_CLIENT_ALLOCATOR=platform python -m tests.reference_algorithm_tests --workload={workload} --framework=jax --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={jax_logs_path}'
+ f' --submission_path=tests/modeldiffs/vanilla_sgd_jax.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}',
+ shell=True,
+ stdout=DEVNULL,
+ stderr=STDOUT,
+ check=True,
+ )
except subprocess.CalledProcessError as e:
- print("Error:", e)
+ print('Error:', e)
try:
run(
- f'XLA_PYTHON_CLIENT_ALLOCATOR=platform torchrun --standalone --nnodes 1 --nproc_per_node 8 -m tests.reference_algorithm_tests --workload={workload} --framework=pytorch --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={pytorch_logs_path}'
- f' --submission_path=tests/modeldiffs/vanilla_sgd_pytorch.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}',
- shell=True,
- stdout=DEVNULL,
- stderr=STDOUT,
- check=True)
+ f'XLA_PYTHON_CLIENT_ALLOCATOR=platform torchrun --standalone --nnodes 1 --nproc_per_node 8 -m tests.reference_algorithm_tests --workload={workload} --framework=pytorch --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={pytorch_logs_path}'
+ f' --submission_path=tests/modeldiffs/vanilla_sgd_pytorch.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}',
+ shell=True,
+ stdout=DEVNULL,
+ stderr=STDOUT,
+ check=True,
+ )
except subprocess.CalledProcessError as e:
- print("Error:", e)
+ print('Error:', e)
with open(jax_logs_path, 'rb') as f:
jax_results = pickle.load(f)
with open(pytorch_logs_path, 'rb') as f:
@@ -75,19 +74,25 @@ def test_workload(self, workload):
# PRINT RESULTS
eval_metric_key = next(
- iter(
- filter(lambda k: 'train' in k and 'loss' in k,
- jax_results['eval_results'][0])))
+ iter(
+ filter(
+ lambda k: 'train' in k and 'loss' in k, jax_results['eval_results'][0]
+ )
+ )
+ )
header = [
- 'Iter',
- 'Eval (jax)',
- 'Eval (torch)',
- 'Grad Norm (jax)',
- 'Grad Norm (torch)',
- 'Train Loss (jax)',
- 'Train Loss (torch)',
+ 'Iter',
+ 'Eval (jax)',
+ 'Eval (torch)',
+ 'Grad Norm (jax)',
+ 'Grad Norm (torch)',
+ 'Train Loss (jax)',
+ 'Train Loss (torch)',
]
- fmt = lambda l: '|' + '|'.join(map(lambda x: f'{x:^20s}', l)) + '|'
+
+ def fmt(line):
+ return '|' + '|'.join(map(lambda x: f'{x:^20s}', line)) + '|'
+
header = fmt(header)
pad = (len(header) - len((name))) // 2
print('=' * pad, name, '=' * (len(header) - len(name) - pad), sep='')
@@ -97,33 +102,41 @@ def test_workload(self, workload):
for i in range(NUM_TRAIN_STEPS):
rtol = 1
- row = map(lambda x: str(round(x, 5)),
- [
- jax_results['eval_results'][i][eval_metric_key],
- pytorch_results['eval_results'][i][eval_metric_key],
- jax_results['scalars'][i]['grad_norm'],
- pytorch_results['scalars'][i]['grad_norm'],
- jax_results['scalars'][i]['loss'],
- pytorch_results['scalars'][i]['loss'],
- ])
+ row = map(
+ lambda x: str(round(x, 5)),
+ [
+ jax_results['eval_results'][i][eval_metric_key],
+ pytorch_results['eval_results'][i][eval_metric_key],
+ jax_results['scalars'][i]['grad_norm'],
+ pytorch_results['scalars'][i]['grad_norm'],
+ jax_results['scalars'][i]['loss'],
+ pytorch_results['scalars'][i]['loss'],
+ ],
+ )
print(fmt([f'{i}', *row]))
print('=' * len(header))
self.assertTrue( # eval_results
- allclose(
- jax_results['eval_results'][i][eval_metric_key],
- pytorch_results['eval_results'][i][eval_metric_key],
- rtol=rtol))
+ allclose(
+ jax_results['eval_results'][i][eval_metric_key],
+ pytorch_results['eval_results'][i][eval_metric_key],
+ rtol=rtol,
+ )
+ )
self.assertTrue( # grad_norms
- allclose(
- jax_results['scalars'][i]['grad_norm'],
- pytorch_results['scalars'][i]['grad_norm'],
- rtol=rtol))
+ allclose(
+ jax_results['scalars'][i]['grad_norm'],
+ pytorch_results['scalars'][i]['grad_norm'],
+ rtol=rtol,
+ )
+ )
self.assertTrue( # loss
- allclose(
- jax_results['scalars'][i]['loss'],
- pytorch_results['scalars'][i]['loss'],
- rtol=rtol))
+ allclose(
+ jax_results['scalars'][i]['loss'],
+ pytorch_results['scalars'][i]['loss'],
+ rtol=rtol,
+ )
+ )
if __name__ == '__main__':
diff --git a/tests/test_version.py b/tests/test_version.py
index d1bfbd18f..69384953a 100644
--- a/tests/test_version.py
+++ b/tests/test_version.py
@@ -6,10 +6,10 @@
def test_version_attribute():
"""Check whether __version__ exists and is a valid string."""
- assert hasattr(algoperf, "__version__")
+ assert hasattr(algoperf, '__version__')
version = algoperf.__version__
assert isinstance(version, str)
- version_elements = version.split(".")
+ version_elements = version.split('.')
print(version_elements)
# Only check the first two elements, i.e. major, minor
# (patch is not checked as it is not required).
diff --git a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py
index d44234927..60a1af2f2 100644
--- a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py
+++ b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py
@@ -1,12 +1,13 @@
"""Tests for imagenet_resnet/imagenet_jax/workload.py."""
-from absl.testing import absltest
import jax
import jax.numpy as jnp
+from absl.testing import absltest
from algoperf import spec
-from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \
- ImagenetResNetWorkload
+from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import (
+ ImagenetResNetWorkload,
+)
def _pytree_total_diff(pytree_a, pytree_b):
@@ -32,42 +33,48 @@ def test_forward_pass(self):
# this function because we call it with a different combination of those two
# args each time. Can't call with kwargs.
pmapped_model_fn = jax.pmap(
- workload.model_fn,
- axis_name='batch',
- in_axes=(0, 0, 0, None, None, None),
- static_broadcasted_argnums=(3, 5))
+ workload.model_fn,
+ axis_name='batch',
+ in_axes=(0, 0, 0, None, None, None),
+ static_broadcasted_argnums=(3, 5),
+ )
logits, updated_batch_stats = pmapped_model_fn(
- model_params,
- {'inputs': first_input_batch},
- batch_stats,
- spec.ForwardPassMode.TRAIN,
- rng,
- True)
+ model_params,
+ {'inputs': first_input_batch},
+ batch_stats,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ True,
+ )
self.assertEqual(logits.shape, expected_logits_shape)
# Test that batch stats are updated.
self.assertNotEqual(
- _pytree_total_diff(batch_stats, updated_batch_stats), 0.0)
+ _pytree_total_diff(batch_stats, updated_batch_stats), 0.0
+ )
second_input_batch = jax.random.normal(data_rngs[1], shape=input_shape)
# Test that batch stats are not updated when we say so.
_, same_batch_stats = pmapped_model_fn(
- model_params,
- {'inputs': second_input_batch},
- updated_batch_stats,
- spec.ForwardPassMode.TRAIN,
- rng,
- False)
+ model_params,
+ {'inputs': second_input_batch},
+ updated_batch_stats,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ False,
+ )
self.assertEqual(
- _pytree_total_diff(same_batch_stats, updated_batch_stats), 0.0)
+ _pytree_total_diff(same_batch_stats, updated_batch_stats), 0.0
+ )
# Test eval model.
logits, _ = pmapped_model_fn(
- model_params,
- {'inputs': second_input_batch},
- batch_stats,
- spec.ForwardPassMode.EVAL,
- rng,
- False)
+ model_params,
+ {'inputs': second_input_batch},
+ batch_stats,
+ spec.ForwardPassMode.EVAL,
+ rng,
+ False,
+ )
self.assertEqual(logits.shape, expected_logits_shape)