diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 992496b69..4308e4201 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -3,7 +3,7 @@ name: Linting on: [push, pull_request] jobs: - pylint: + ruff-linting: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 @@ -11,19 +11,15 @@ jobs: uses: actions/setup-python@v2 with: python-version: 3.11.10 - - name: Install pylint + - name: Install ruff run: | python -m pip install --upgrade pip - pip install pylint==2.16.1 - - name: Run pylint + pip install ruff==0.12.0 + - name: Run ruff linter run: | - pylint algoperf - pylint reference_algorithms - pylint prize_qualification_baselines - pylint submission_runner.py - pylint tests + ruff check - isort: + ruff-formatter: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 @@ -31,26 +27,11 @@ jobs: uses: actions/setup-python@v2 with: python-version: 3.11.10 - - name: Install isort + - name: Install ruff run: | python -m pip install --upgrade pip - pip install isort==5.12.0 - - name: Run isort + pip install ruff==0.12.0 + - name: Run ruff formatter run: | - isort . --check --diff + ruff format --check - yapf: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - name: Set up Python 3.11.10 - uses: actions/setup-python@v2 - with: - python-version: 3.11.10 - - name: Install yapf - run: | - python -m pip install --upgrade pip - pip install yapf==0.32 toml - - name: Run yapf - run: | - yapf . --diff --recursive diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cc8f13d25..f2d684c53 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,14 +1,10 @@ repos: - - repo: https://github.com/google/yapf - rev: v0.32.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.12.0 hooks: - - id: yapf - args: ["--in-place", "--parallel", "--verbose", "--recursive"] - - repo: https://github.com/pycqa/isort - rev: 5.10.1 - hooks: - - id: isort - - repo: https://github.com/pycqa/pylint - rev: v2.16.1 - hooks: - - id: pylint + # Run the linter (don't change files). + - id: ruff-check + # Run the formatter (don't change files). + - id: ruff-format + args: ["--check"] diff --git a/README.md b/README.md index 8e470266d..0666d21d5 100644 --- a/README.md +++ b/README.md @@ -14,11 +14,11 @@ Benchmark/Results Paper

-[![CI](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/CI.yml/badge.svg)](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/CI.yml) -[![Lint](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/linting.yml/badge.svg)](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/linting.yml) -[![License: Apache 2.0](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://github.com/mlcommons/algorithmic-efficiency/blob/main/LICENSE.md) -[![Code style: yapf](https://img.shields.io/badge/code%20style-yapf-orange)](https://github.com/google/yapf) -[![Discord](https://dcbadge.vercel.app/api/server/5FPXK7SMt6?style=flat)](https://discord.gg/5FPXK7SMt6) +[![CI Status](https://img.shields.io/github/actions/workflow/status/mlcommons/algorithmic-efficiency/CI.yml?style=flat-square&logo=github&label=CI)](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/CI.yml) +[![Linting Status](https://img.shields.io/github/actions/workflow/status/mlcommons/algorithmic-efficiency/linting.yml?style=flat-square&logo=github&label=Linting)](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/linting.yml) +[![Code Style Ruff](https://img.shields.io/badge/Code%20Style-Ruff-brightgreen?style=flat-square&logo=ruff)](https://github.com/astral-sh/ruff) +[![GitHub License](https://img.shields.io/github/license/mlcommons/algorithmic-efficiency?style=flat-square&label=License)](LICENSE.md) +[![Discord](https://dcbadge.limes.pink/api/server/5FPXK7SMt6?style=flat-square)](https://discord.gg/5FPXK7SMt6) --- @@ -28,11 +28,12 @@ Submissions are evaluated based on their "time-to-result", i.e., the wall-clock --- -> This is the repository for the *AlgoPerf: Training Algorithms benchmark* measuring neural network training speedups due to algorithmic improvements. +> This is the repository for the _AlgoPerf: Training Algorithms benchmark_ measuring neural network training speedups due to algorithmic improvements. > It is developed by the [MLCommons Algorithms Working Group](https://mlcommons.org/en/groups/research-algorithms/). > This repository holds the benchmark code, the benchmark's [**technical documentation**](/docs/DOCUMENTATION.md) and [**getting started guides**](/docs/GETTING_STARTED.md). For a detailed description of the benchmark design, see our [**introductory paper**](https://arxiv.org/abs/2306.07179), for the results of the inaugural competition see our [**results paper**](https://openreview.net/forum?id=CtM5xjRSfm). > > **See our [AlgoPerf Leaderboard](https://github.com/mlcommons/submissions_algorithms) for the latest results of the benchmark and to submit your algorithm.** + --- > [!IMPORTANT] @@ -50,14 +51,13 @@ Submissions are evaluated based on their "time-to-result", i.e., the wall-clock ## Installation -> [!TIP] -> **If you have any questions about the benchmark competition or you run into any issues, please feel free to contact us.** Either [file an issue](https://github.com/mlcommons/algorithmic-efficiency/issues), ask a question on [our Discord](https://discord.gg/5FPXK7SMt6) or [join our weekly meetings](https://mlcommons.org/en/groups/research-algorithms/). 
+> [!TIP] > **If you have any questions about the benchmark competition or you run into any issues, please feel free to contact us.** Either [file an issue](https://github.com/mlcommons/algorithmic-efficiency/issues), ask a question on [our Discord](https://discord.gg/5FPXK7SMt6) or [join our weekly meetings](https://mlcommons.org/en/groups/research-algorithms/). You can install this package and dependencies in a [Python virtual environment](/docs/GETTING_STARTED.md#python-virtual-environment) or use a [Docker/Singularity/Apptainer container](/docs/GETTING_STARTED.md#docker) (recommended). We recommend using a Docker container (or alternatively, a Singularity/Apptainer container) to ensure a similar environment to our scoring and testing environments. Both options are described in detail in the [**Getting Started**](/docs/GETTING_STARTED.md) document. -*TL;DR to install the Jax version for GPU run:* +_TL;DR to install the Jax version for GPU run:_ ```bash pip3 install -e '.[pytorch_cpu]' @@ -65,7 +65,7 @@ pip3 install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax pip3 install -e '.[full]' ``` -*TL;DR to install the PyTorch version for GPU run:* +_TL;DR to install the PyTorch version for GPU run:_ ```bash pip3 install -e '.[jax_cpu]' @@ -77,7 +77,7 @@ pip3 install -e '.[full]' For detailed instructions on developing your own algorithm in the benchmark see the [Getting Started](/docs/GETTING_STARTED.md) document. -*TL;DR running a JAX workload:* +_TL;DR running a JAX workload:_ ```bash python3 submission_runner.py \ @@ -89,7 +89,7 @@ python3 submission_runner.py \ --tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json ``` -*TL;DR running a PyTorch workload:* +_TL;DR running a PyTorch workload:_ ```bash python3 submission_runner.py \ @@ -117,17 +117,15 @@ Our [**Contributing**](/docs/CONTRIBUTING.md) document provides further MLCommon ## License -The *AlgoPerf* codebase is licensed under the [Apache License 2.0](/LICENSE.md). +The _AlgoPerf_ codebase is licensed under the [Apache License 2.0](/LICENSE.md). ## Paper and Citing the AlgoPerf Benchmark -In our paper ["Benchmarking Neural Network Training Algorithms"](http://arxiv.org/abs/2306.07179) we motivate, describe, and justify the *AlgoPerf: Training Algorithms* benchmark. +In our paper ["Benchmarking Neural Network Training Algorithms"](http://arxiv.org/abs/2306.07179) we motivate, describe, and justify the _AlgoPerf: Training Algorithms_ benchmark. -If you are using the *AlgoPerf benchmark*, its codebase, baselines, or workloads, please consider citing our paper: +If you are using the _AlgoPerf benchmark_, its codebase, baselines, or workloads, please consider citing our paper: -> [Dahl, Schneider, Nado, et al.
-> **Benchmarking Neural Network Training Algorithms**
-> *arXiv 2306.07179*](http://arxiv.org/abs/2306.07179) +> [Dahl, Schneider, Nado, et al.
> **Benchmarking Neural Network Training Algorithms**
> _arXiv 2306.07179_](http://arxiv.org/abs/2306.07179) ```bibtex @Misc{Dahl2023AlgoPerf, @@ -139,10 +137,9 @@ If you are using the *AlgoPerf benchmark*, its codebase, baselines, or workloads } ``` -If you use the results from the first *AlgoPerf competition*, please consider citing the results paper, as well as the relevant submissions: +If you use the results from the first _AlgoPerf competition_, please consider citing the results paper, as well as the relevant submissions: -> [Kasimbeg, Schneider, Eschenhagen, et al.
-> **Accelerating neural network training: An analysis of the AlgoPerf competition**
+> [Kasimbeg, Schneider, Eschenhagen, et al.
> **Accelerating neural network training: An analysis of the AlgoPerf competition**
> ICLR 2025](https://openreview.net/forum?id=CtM5xjRSfm) ```bibtex diff --git a/algoperf/__init__.py b/algoperf/__init__.py index 7d54f8290..5ecee05af 100644 --- a/algoperf/__init__.py +++ b/algoperf/__init__.py @@ -2,4 +2,4 @@ from ._version import version as __version__ -__all__ = ["__version__"] +__all__ = ['__version__'] diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index f4cb6c2db..f8cc40599 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -7,37 +7,41 @@ import os from typing import Sequence, Tuple +import jax +import numpy as np +import torch from absl import logging from flax import jax_utils from flax.training import checkpoints as flax_checkpoints from flax.training.checkpoints import latest_checkpoint -import jax -import numpy as np from tensorflow.io import gfile # pytype: disable=import-error -import torch from algoperf import spec from algoperf.pytorch_utils import pytorch_setup _, _, DEVICE, _ = pytorch_setup() -CheckpointReturn = Tuple[spec.OptimizerState, - spec.ParameterContainer, - spec.ModelAuxiliaryState, - dict, - list, - int, - int] - - -def maybe_restore_checkpoint(framework: str, - optimizer_state: spec.OptimizerState, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - train_state: dict, - eval_results: list, - global_step: int, - preemption_count: int, - checkpoint_dir: str) -> CheckpointReturn: +CheckpointReturn = Tuple[ + spec.OptimizerState, + spec.ParameterContainer, + spec.ModelAuxiliaryState, + dict, + list, + int, + int, +] + + +def maybe_restore_checkpoint( + framework: str, + optimizer_state: spec.OptimizerState, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + train_state: dict, + eval_results: list, + global_step: int, + preemption_count: int, + checkpoint_dir: str, +) -> CheckpointReturn: """Optionally restores from a checkpoint. 
The checkpoint logic is as follows: if there is a checkpoint in @@ -69,20 +73,22 @@ def maybe_restore_checkpoint(framework: str, uninitialized_global_step = -1 uninitialized_preemption_count = -1 checkpoint_state = { - 'model_params': model_params, - 'optimizer_state': opt_state, - 'model_state': model_state, - 'train_state': train_state, - 'eval_results': None, - 'global_step': uninitialized_global_step, - 'preemption_count': uninitialized_preemption_count, + 'model_params': model_params, + 'optimizer_state': opt_state, + 'model_state': model_state, + 'train_state': train_state, + 'eval_results': None, + 'global_step': uninitialized_global_step, + 'preemption_count': uninitialized_preemption_count, } if framework == 'jax': latest_ckpt = flax_checkpoints.restore_checkpoint( - checkpoint_dir, target=checkpoint_state) - save_path = os.path.join(checkpoint_dir, - 'checkpoint_' + str(latest_ckpt['global_step'])) + checkpoint_dir, target=checkpoint_state + ) + save_path = os.path.join( + checkpoint_dir, 'checkpoint_' + str(latest_ckpt['global_step']) + ) else: latest_ckpt = checkpoint_state save_path = latest_checkpoint(checkpoint_dir) @@ -94,55 +100,64 @@ def maybe_restore_checkpoint(framework: str, found_checkpoint = latest_ckpt['global_step'] != uninitialized_global_step if not found_checkpoint: - return (optimizer_state, - model_params, - model_state, - train_state, - eval_results, - global_step, - preemption_count) + return ( + optimizer_state, + model_params, + model_state, + train_state, + eval_results, + global_step, + preemption_count, + ) # If there's the latest checkpoint in the checkpoint_dir, restore from that. if framework == 'jax': checkpoint_state = replicate_checkpoint( - latest_ckpt, - pytree_keys=[ - 'optimizer_state', - 'model_params', - 'model_state', - ]) - checkpoint_state['optimizer_state'] = (checkpoint_state['optimizer_state'], - opt_update_fn) + latest_ckpt, + pytree_keys=[ + 'optimizer_state', + 'model_params', + 'model_state', + ], + ) + checkpoint_state['optimizer_state'] = ( + checkpoint_state['optimizer_state'], + opt_update_fn, + ) checkpoint_state['eval_results'] = [ - (value, key) for key, value in latest_ckpt['eval_results'].items() + (value, key) for key, value in latest_ckpt['eval_results'].items() ] else: checkpoint_state = latest_ckpt if isinstance( - model_params, - (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): + model_params, + (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel), + ): model_params = model_params.module model_params.load_state_dict(checkpoint_state['model_params']) checkpoint_state['model_params'] = model_params for key in optimizer_state.keys(): optimizer_state[key].load_state_dict( - checkpoint_state['optimizer_state'][key]) + checkpoint_state['optimizer_state'][key] + ) checkpoint_state['optimizer_state'][key] = optimizer_state[key] logging.info(f'Loaded checkpoint from {save_path}.') - return (checkpoint_state['optimizer_state'], - checkpoint_state['model_params'], - checkpoint_state['model_state'], - checkpoint_state['train_state'], - list(checkpoint_state['eval_results']), - checkpoint_state['global_step'], - checkpoint_state['preemption_count'] + 1) - - -def replicate_checkpoint(latest: dict, - pytree_keys: Sequence[str], - replicate: bool = True) -> dict: + return ( + checkpoint_state['optimizer_state'], + checkpoint_state['model_params'], + checkpoint_state['model_state'], + checkpoint_state['train_state'], + list(checkpoint_state['eval_results']), + checkpoint_state['global_step'], + 
checkpoint_state['preemption_count'] + 1, + ) + + +def replicate_checkpoint( + latest: dict, pytree_keys: Sequence[str], replicate: bool = True +) -> dict: """Restores from the provided checkpoint. Args: @@ -163,16 +178,18 @@ def replicate_checkpoint(latest: dict, return pytree -def save_checkpoint(framework: str, - optimizer_state: spec.OptimizerState, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - train_state: dict, - eval_results: list, - global_step: int, - preemption_count: int, - checkpoint_dir: str, - save_intermediate_checkpoints: bool) -> None: +def save_checkpoint( + framework: str, + optimizer_state: spec.OptimizerState, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + train_state: dict, + eval_results: list, + global_step: int, + preemption_count: int, + checkpoint_dir: str, + save_intermediate_checkpoints: bool, +) -> None: """Save the checkpoint in `checkpoint_dir`. Args: @@ -199,8 +216,9 @@ def save_checkpoint(framework: str, model_state = jax.device_get(jax_utils.unreplicate(model_state)) else: if isinstance( - model_params, - (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): + model_params, + (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel), + ): model_params = model_params.module model_params = model_params.state_dict() optimizer_state_dict = {} @@ -209,33 +227,36 @@ def save_checkpoint(framework: str, optimizer_state_dict[key] = optimizer_state[key].state_dict() else: logging.warning( - f'The optimizer state for key {key} is not saved, because ' - f'{type(optimizer_state[key])} has not implemented a state_dict() ' - 'method.') + f'The optimizer state for key {key} is not saved, because ' + f'{type(optimizer_state[key])} has not implemented a state_dict() ' + 'method.' 
+ ) opt_state = optimizer_state_dict checkpoint_state = { - 'model_params': model_params, - 'optimizer_state': opt_state, - 'model_state': model_state, - 'train_state': train_state, - 'eval_results': tuple(eval_results), - 'global_step': global_step, - 'preemption_count': preemption_count, + 'model_params': model_params, + 'optimizer_state': opt_state, + 'model_state': model_state, + 'train_state': train_state, + 'eval_results': tuple(eval_results), + 'global_step': global_step, + 'preemption_count': preemption_count, } save_path = os.path.join(checkpoint_dir, f'checkpoint_{global_step}') if framework == 'jax': flax_checkpoints.save_checkpoint( - checkpoint_dir, - target=checkpoint_state, - step=global_step, - overwrite=True, - keep=np.inf if save_intermediate_checkpoints else 1) + checkpoint_dir, + target=checkpoint_state, + step=global_step, + overwrite=True, + keep=np.inf if save_intermediate_checkpoints else 1, + ) else: if not save_intermediate_checkpoints: checkpoint_files = gfile.glob( - os.path.join(checkpoint_dir, 'checkpoint_*')) + os.path.join(checkpoint_dir, 'checkpoint_*') + ) for path in checkpoint_files: logging.info('Removing checkpoint at %s', path) gfile.rmtree(path) diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 37d1bd20f..f08d9d2db 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -7,17 +7,16 @@ import torch import torch.distributed as dist import torch.nn.functional as F -from torch.utils.data import DataLoader -from torch.utils.data import DistributedSampler -from torch.utils.data import Sampler +from torch.utils.data import DataLoader, DistributedSampler, Sampler from algoperf import spec def shard_and_maybe_pad_np( - batch: Dict[str, spec.Tensor], - padding_value: int = 0, - global_batch_size: Optional[int] = None) -> Dict[str, spec.Tensor]: + batch: Dict[str, spec.Tensor], + padding_value: int = 0, + global_batch_size: Optional[int] = None, +) -> Dict[str, spec.Tensor]: """Prepare tf data for JAX or PyTorch DDP. Convert an input batch from tf Tensors to numpy arrays, pad it with @@ -26,11 +25,13 @@ def shard_and_maybe_pad_np( """ local_device_count = max(torch.cuda.device_count(), jax.local_device_count()) inputs = batch['inputs'] - current_batch_size = inputs[0].shape[0] if isinstance( - inputs, tuple) else inputs.shape[0] + current_batch_size = ( + inputs[0].shape[0] if isinstance(inputs, tuple) else inputs.shape[0] + ) if global_batch_size is not None: - assert global_batch_size >= current_batch_size, \ - 'global_batch_size must be larger than or equal to current_batch_size.' + assert global_batch_size >= current_batch_size, ( + 'global_batch_size must be larger than or equal to current_batch_size.' + ) # Always pad to global_batch_size if it is provided. pad_to_global_batch_size = global_batch_size > current_batch_size else: @@ -43,7 +44,8 @@ def shard_and_maybe_pad_np( pad_size = local_device_count - remainder_size targets = batch['targets'] targets_shape = tuple( - targets[0].shape if isinstance(targets, tuple) else targets.shape) + targets[0].shape if isinstance(targets, tuple) else targets.shape + ) # We need a 2d mask for WMT. mask_shape = targets_shape if len(targets_shape) < 3 else targets_shape[0] # Get weights from batch if there are any. 
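The hunks above reformat `shard_and_maybe_pad_np`, which pads a batch so its leading dimension splits evenly across local devices before sharding. As a minimal sketch of that padding step only (NumPy-only, single-array batch; `pad_to_device_multiple` is an illustrative helper name, not part of this patch):

```python
import numpy as np


def pad_to_device_multiple(
    batch: np.ndarray, num_devices: int, padding_value: int = 0
) -> np.ndarray:
  """Pad the batch dimension so it is divisible by the number of devices."""
  remainder = batch.shape[0] % num_devices
  if remainder == 0:
    return batch
  pad_size = num_devices - remainder
  # Fill the padded rows with `padding_value`, keeping the trailing shape.
  padding = np.full(
      (pad_size, *batch.shape[1:]), padding_value, dtype=batch.dtype
  )
  return np.concatenate((batch, padding))


# A batch of 10 examples padded for 4 devices becomes shape (12, 3).
padded = pad_to_device_multiple(np.ones((10, 3)), num_devices=4)
assert padded.shape == (12, 3)
```
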
@@ -68,9 +70,9 @@ def _prepare(x): return jax.tree.map(_prepare, batch) -def pad(tensor: np.ndarray, - pad_size: int, - padding_value: int = 0) -> np.ndarray: +def pad( + tensor: np.ndarray, pad_size: int, padding_value: int = 0 +) -> np.ndarray: if tensor.ndim > 1: pad_size = (pad_size, *tensor.shape[1:]) padding = np.full(pad_size, padding_value, dtype=tensor.dtype) @@ -78,8 +80,9 @@ def pad(tensor: np.ndarray, return padded_tensor -def mixup_pytorch(batch: Tuple[spec.Tensor, spec.Tensor], - alpha: float = 0.2) -> Tuple[spec.Tensor, spec.Tensor]: +def mixup_pytorch( + batch: Tuple[spec.Tensor, spec.Tensor], alpha: float = 0.2 +) -> Tuple[spec.Tensor, spec.Tensor]: inputs, targets = batch # Transform to one-hot targets. targets = F.one_hot(targets, num_classes=1000) @@ -144,12 +147,14 @@ class DistributedEvalSampler(Sampler): ... train(loader) """ - def __init__(self, - dataset: torch.utils.data.Dataset, - num_replicas: Optional[int] = None, - rank: Optional[int] = None, - shuffle: bool = False, - seed: int = 0) -> None: + def __init__( + self, + dataset: torch.utils.data.Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = False, + seed: int = 0, + ) -> None: if num_replicas is None: if not dist.is_available(): raise RuntimeError('Requires distributed package to be available.') @@ -165,7 +170,7 @@ def __init__(self, # true value without extra samples self.total_size = len(self.dataset) indices = list(range(self.total_size)) - indices = indices[self.rank:self.total_size:self.num_replicas] + indices = indices[self.rank : self.total_size : self.num_replicas] # true value without extra samples self.num_samples = len(indices) @@ -182,7 +187,7 @@ def __iter__(self) -> Iterable[int]: indices = list(range(len(self.dataset))) # Subsample. - indices = indices[self.rank:self.total_size:self.num_replicas] + indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples return iter(indices) @@ -203,11 +208,13 @@ def set_epoch(self, epoch: int) -> None: # Modified from github.com/pytorch/pytorch/issues/23900#issuecomment-518858050. -def cycle(iterable: Iterable, - keys: Tuple[str, ...] = ('inputs', 'targets'), - custom_sampler: bool = False, - use_mixup: bool = False, - mixup_alpha: float = 0.2) -> Iterable: +def cycle( + iterable: Iterable, + keys: Tuple[str, ...] 
= ('inputs', 'targets'), + custom_sampler: bool = False, + use_mixup: bool = False, + mixup_alpha: float = 0.2, +) -> Iterable: iterator = iter(iterable) epoch = 0 while True: @@ -229,11 +236,9 @@ def cycle(iterable: Iterable, # github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Classification/ # ConvNets/image_classification/dataloaders.py class PrefetchedWrapper: - - def __init__(self, - dataloader: DataLoader, - device: torch.device, - start_epoch: int = 0) -> None: + def __init__( + self, dataloader: DataLoader, device: torch.device, start_epoch: int = 0 + ) -> None: self.dataloader = dataloader self.epoch = start_epoch self.device = device @@ -254,11 +259,12 @@ def prefetched_loader(self) -> Iterable[Tuple[spec.Tensor, spec.Tensor]]: for next_inputs, next_targets in self.dataloader: with torch.cuda.stream(stream): next_inputs = next_inputs.to( - self.device, dtype=torch.float, non_blocking=True) + self.device, dtype=torch.float, non_blocking=True + ) next_targets = next_targets.to(self.device, non_blocking=True) if not first: - yield inputs, targets + yield inputs, targets # noqa: F821 else: first = False diff --git a/algoperf/halton.py b/algoperf/halton.py index 1f36b07bf..08c5466f1 100644 --- a/algoperf/halton.py +++ b/algoperf/halton.py @@ -36,10 +36,12 @@ def _is_prime(n: int) -> bool: return all(n % i != 0 for i in range(2, int(n**0.5) + 1)) and n != 2 -def _generate_dim(num_samples: int, - base: int, - per_dim_shift: bool, - shuffled_seed_sequence: List[int]) -> List[float]: +def _generate_dim( + num_samples: int, + base: int, + per_dim_shift: bool, + shuffled_seed_sequence: List[int], +) -> List[float]: """Generate `num_samples` from a Van der Corput sequence with base `base`. Args: @@ -59,8 +61,9 @@ def _generate_dim(num_samples: int, ValueError: if `base` is negative or not prime. """ if base < 0 or not _is_prime(base): - raise ValueError('Each Van der Corput sequence requires a prime `base`, ' - f'received {base}.') + raise ValueError( + f'Each Van der Corput sequence requires a prime `base`, received {base}.' + ) rng = random.RandomState(base) if shuffled_seed_sequence is None: @@ -76,7 +79,7 @@ def _generate_dim(num_samples: int, dim_sequence = [] for i in range(1, num_samples + 1): - num = 0. + num = 0.0 denominator = base while i: num += shuffled_seed_sequence[i % base] / denominator @@ -91,13 +94,15 @@ def _generate_dim(num_samples: int, Matrix = List[List[int]] -def generate_sequence(num_samples: int, - num_dims: int, - skip: int = 100, - per_dim_shift: bool = True, - shuffle_sequence: bool = True, - primes: Sequence[int] = None, - shuffled_seed_sequence: Matrix = None) -> Matrix: +def generate_sequence( + num_samples: int, + num_dims: int, + skip: int = 100, + per_dim_shift: bool = True, + shuffle_sequence: bool = True, + primes: Sequence[int] = None, + shuffled_seed_sequence: Matrix = None, +) -> Matrix: """Generate `num_samples` from a Halton sequence of dimension `num_dims`. Each dimension is generated independently from a shuffled Van der Corput @@ -140,25 +145,29 @@ def generate_sequence(num_samples: int, if primes is not None and len(primes) != num_dims: raise ValueError( - 'If passing in a sequence of primes it must be the same length as ' - f'num_dims={num_dims}, received {primes} (len {len(primes)}).') + 'If passing in a sequence of primes it must be the same length as ' + f'num_dims={num_dims}, received {primes} (len {len(primes)}).' 
+ ) if shuffled_seed_sequence is not None: if len(shuffled_seed_sequence) != num_dims: raise ValueError( - 'If passing in `shuffled_seed_sequence` it must be the same length ' - f'as num_dims={num_dims}, received {shuffled_seed_sequence} ' - f'(len {len(shuffled_seed_sequence)}).') + 'If passing in `shuffled_seed_sequence` it must be the same length ' + f'as num_dims={num_dims}, received {shuffled_seed_sequence} ' + f'(len {len(shuffled_seed_sequence)}).' + ) for d in range(num_dims): if len(shuffled_seed_sequence[d]) != primes[d]: raise ValueError( - 'If passing in `shuffled_seed_sequence` it must have element `{d}` ' - 'be a sequence of length `primes[{d}]`={expected}, received ' - '{actual} (len {length})'.format( - d=d, - expected=primes[d], - actual=shuffled_seed_sequence[d], - length=shuffled_seed_sequence[d])) + 'If passing in `shuffled_seed_sequence` it must have element `{d}` ' + 'be a sequence of length `primes[{d}]`={expected}, received ' + '{actual} (len {length})'.format( + d=d, + expected=primes[d], + actual=shuffled_seed_sequence[d], + length=shuffled_seed_sequence[d], + ) + ) if primes is None: primes = [] @@ -166,7 +175,7 @@ def generate_sequence(num_samples: int, while len(primes) < num_dims + 1: primes = generate_primes(1000 * prime_attempts) prime_attempts += 1 - primes = primes[-num_dims - 1:-1] + primes = primes[-num_dims - 1 : -1] # Skip the first `skip` points in the sequence because they can have unwanted # correlations. @@ -179,10 +188,11 @@ def generate_sequence(num_samples: int, else: dim_shuffled_seed_sequence = shuffled_seed_sequence[d] dim_sequence = _generate_dim( - num_samples=num_samples, - base=primes[d], - shuffled_seed_sequence=dim_shuffled_seed_sequence, - per_dim_shift=per_dim_shift) + num_samples=num_samples, + base=primes[d], + shuffled_seed_sequence=dim_shuffled_seed_sequence, + per_dim_shift=per_dim_shift, + ) dim_sequence = dim_sequence[skip:] halton_sequence.append(dim_sequence) @@ -195,29 +205,29 @@ def generate_sequence(num_samples: int, return halton_sequence -def _generate_double_point(name: str, - min_val: float, - max_val: float, - scaling: str, - halton_point: float) -> Tuple[str, float]: +def _generate_double_point( + name: str, min_val: float, max_val: float, scaling: str, halton_point: float +) -> Tuple[str, float]: """Generate a float hyperparameter value from a Halton sequence point.""" if scaling not in ['linear', 'log']: raise ValueError( - 'Only log or linear scaling is supported for floating point ' - f'parameters. Received {scaling}.') + 'Only log or linear scaling is supported for floating point ' + f'parameters. Received {scaling}.' + ) if scaling == 'log': # To transform from [0, 1] to [min_val, max_val] on a log scale we do: # min_val * exp(x * log(max_val / min_val)). 
- rescaled_value = ( - min_val * math.exp(halton_point * math.log(max_val / min_val))) + rescaled_value = min_val * math.exp( + halton_point * math.log(max_val / min_val) + ) else: rescaled_value = halton_point * (max_val - min_val) + min_val return name, rescaled_value -def _generate_discrete_point(name: str, - feasible_points: Sequence[Any], - halton_point: float) -> Any: +def _generate_discrete_point( + name: str, feasible_points: Sequence[Any], halton_point: float +) -> Any: """Generate a discrete hyperparameter value from a Halton sequence point.""" index = int(math.floor(halton_point * len(feasible_points))) return name, feasible_points[index] @@ -236,27 +246,23 @@ def interval(start: int, end: int) -> Tuple[int, int]: def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn: min_val, max_val = range_endpoints - return functools.partial(_generate_double_point, - name, - min_val, - max_val, - 'log') + return functools.partial( + _generate_double_point, name, min_val, max_val, 'log' + ) def uniform( - name: str, search_points: Union[_DiscretePoints, - Tuple[int, int]]) -> _GeneratorFn: + name: str, search_points: Union[_DiscretePoints, Tuple[int, int]] +) -> _GeneratorFn: if isinstance(search_points, _DiscretePoints): - return functools.partial(_generate_discrete_point, - name, - search_points.feasible_points) + return functools.partial( + _generate_discrete_point, name, search_points.feasible_points + ) min_val, max_val = search_points - return functools.partial(_generate_double_point, - name, - min_val, - max_val, - 'linear') + return functools.partial( + _generate_double_point, name, min_val, max_val, 'linear' + ) def product(sweeps: Sequence[_SweepSequence]) -> _SweepSequence: @@ -277,9 +283,10 @@ def sweep(name, feasible_points: Sequence[Any]) -> _SweepSequence: return [{name: x} for x in feasible_points.feasible_points] -def zipit(generator_fns_or_sweeps: Sequence[Union[_GeneratorFn, - _SweepSequence]], - length: int) -> _SweepSequence: +def zipit( + generator_fns_or_sweeps: Sequence[Union[_GeneratorFn, _SweepSequence]], + length: int, +) -> _SweepSequence: """Zip together a list of hyperparameter generators. Args: @@ -302,7 +309,8 @@ def zipit(generator_fns_or_sweeps: Sequence[Union[_GeneratorFn, hyperparameter name from generator_fns_or_sweeps. """ halton_sequence = generate_sequence( - num_samples=length, num_dims=len(generator_fns_or_sweeps)) + num_samples=length, num_dims=len(generator_fns_or_sweeps) + ) # A List[Dict] of hyperparameter names to sweep values. hyperparameter_sweep = [] for trial_index in range(length): @@ -326,8 +334,9 @@ def zipit(generator_fns_or_sweeps: Sequence[Union[_GeneratorFn, _ListSearchSpace = List[Dict[str, Union[str, float, Sequence]]] -def generate_search(search_space: Union[_DictSearchSpace, _ListSearchSpace], - num_trials: int) -> List[collections.namedtuple]: +def generate_search( + search_space: Union[_DictSearchSpace, _ListSearchSpace], num_trials: int +) -> List[collections.namedtuple]: """Generate a random search with the given bounds and scaling. 
Args:linear @@ -352,8 +361,9 @@ def generate_search(search_space: Union[_DictSearchSpace, _ListSearchSpace], else: raise AttributeError('tuning_search_space should either be a dict or list.') - named_tuple_class = collections.namedtuple('Hyperparameters', - all_hyperparameter_names) + named_tuple_class = collections.namedtuple( + 'Hyperparameters', all_hyperparameter_names + ) if isinstance(search_space, dict): hyperparameter_generators = [] @@ -367,16 +377,18 @@ def generate_search(search_space: Union[_DictSearchSpace, _ListSearchSpace], generator_fn = uniform(name, interval(space['min'], space['max'])) hyperparameter_generators.append(generator_fn) return [ - named_tuple_class(**p) - for p in zipit(hyperparameter_generators, num_trials) + named_tuple_class(**p) + for p in zipit(hyperparameter_generators, num_trials) ] else: hyperparameters = [] updated_num_trials = min(num_trials, len(search_space)) if num_trials != len(search_space): - logging.info(f'--num_tuning_trials was set to {num_trials}, but ' - f'{len(search_space)} trial(s) found in the JSON file. ' - f'Updating --num_tuning_trials to {updated_num_trials}.') + logging.info( + f'--num_tuning_trials was set to {num_trials}, but ' + f'{len(search_space)} trial(s) found in the JSON file. ' + f'Updating --num_tuning_trials to {updated_num_trials}.' + ) for trial in search_space: hyperparameters.append(named_tuple_class(**trial)) return hyperparameters[:updated_num_trials] diff --git a/algoperf/init_utils.py b/algoperf/init_utils.py index 185480cc7..c66a0be20 100644 --- a/algoperf/init_utils.py +++ b/algoperf/init_utils.py @@ -12,7 +12,7 @@ def pytorch_default_init(module: nn.Module) -> None: # Perform lecun_normal initialization. fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) - std = math.sqrt(1. / fan_in) / .87962566103423978 + std = math.sqrt(1.0 / fan_in) / 0.87962566103423978 nn.init.trunc_normal_(module.weight, std=std, a=-2 * std, b=2 * std) if module.bias is not None: - nn.init.constant_(module.bias, 0.) + nn.init.constant_(module.bias, 0.0) diff --git a/algoperf/interop_utils.py b/algoperf/interop_utils.py index 0c6535d7a..c30d0cf3b 100644 --- a/algoperf/interop_utils.py +++ b/algoperf/interop_utils.py @@ -6,7 +6,8 @@ def jax_to_pytorch(x: spec.Tensor, take_ownership: bool = False) -> spec.Tensor: return torch.utils.dlpack.from_dlpack( - jax.dlpack.to_dlpack(x, take_ownership=take_ownership)) + jax.dlpack.to_dlpack(x, take_ownership=take_ownership) + ) def pytorch_to_jax(x: torch.Tensor) -> spec.Tensor: diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py new file mode 100644 index 000000000..dab338328 --- /dev/null +++ b/algoperf/jax_utils.py @@ -0,0 +1,129 @@ +from collections.abc import Sequence + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.linen.module import Module, compact, merge_param +from flax.typing import PRNGKey +from jax import lax, random + + +# Custom Layers +class Dropout(Module): + # pylint: disable=line-too-long + """Create a dropout layer. + Forked from + https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. + The reference dropout implementation is modified support changes + to dropout rate during training by: + 1) adding rate argument to the __call__ method. + 2) removing the if-else condition to check for edge cases, which + will trigger a recompile for jitted code. + + .. note:: + When using :meth:`Module.apply() `, make sure + to include an RNG seed named ``'dropout'``. 
Dropout isn't necessary for + variable initialization. + + Example usage:: + + >>> import flax.linen as nn + >>> import jax, jax.numpy as jnp + + >>> class MLP(nn.Module): + ... @nn.compact + ... def __call__(self, x, train): + ... x = nn.Dense(4)(x) + ... x = nn.Dropout(0.5, deterministic=not train)(x) + ... return x + + >>> model = MLP() + >>> x = jnp.ones((1, 3)) + >>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout + >>> model.apply(variables, x, train=False) # don't use dropout + Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32) + >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout + Array([[-0.35751054, 3.2510893 , 0. , 0. ]], dtype=float32) + + Attributes: + rate: the dropout probability. (_not_ the keep rate!) + broadcast_dims: dimensions that will share the same dropout mask + deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` + and masked, whereas if true, no mask is applied and the inputs are + returned as is. + rng_collection: the rng collection name to use when requesting an rng + key. + """ + + rate: float | None = None + broadcast_dims: Sequence[int] = () + deterministic: bool | None = None + rng_collection: str = 'dropout' + legacy: bool = False + + @compact + def __call__( + self, + inputs, + deterministic: bool | None = None, + rate: float | None = None, + rng: PRNGKey | None = None, + ): + """Applies a random dropout mask to the input. + + Args: + inputs: the inputs that should be randomly masked. + deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` + and masked, whereas if true, no mask is applied and the inputs are + returned as is. + rate: the dropout probability. (_not_ the keep rate!) + rng: an optional PRNGKey used as the random key, if not specified, + one will be generated using ``make_rng`` with the + ``rng_collection`` name. + + Returns: + The masked inputs reweighted to preserve mean. + """ + deterministic = merge_param( + 'deterministic', self.deterministic, deterministic + ) + + # Override self.rate if rate is passed to __call__ + if rate is None: + rate = self.rate + + if self.legacy: + if rate == 0.0: + return inputs + + # Prevent gradient NaNs in 1.0 edge-case. 
+ if rate == 1.0: + return jnp.zeros_like(inputs) + + if deterministic: + return inputs + + keep_prob = 1.0 - rate + if rng is None: + rng = self.make_rng(self.rng_collection) + broadcast_shape = list(inputs.shape) + for dim in self.broadcast_dims: + broadcast_shape[dim] = 1 + mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) + mask = jnp.broadcast_to(mask, inputs.shape) + return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)) + + +# Utilities for debugging +def print_jax_model_summary(model, fake_inputs): + """Prints a summary of the jax module.""" + tabulate_fn = nn.tabulate( + model, + jax.random.PRNGKey(0), + console_kwargs={ + 'force_terminal': False, + 'force_jupyter': False, + 'width': 240, + }, + ) + print(tabulate_fn(fake_inputs, train=False)) diff --git a/algoperf/logger_utils.py b/algoperf/logger_utils.py index c988956dc..17eea74a6 100644 --- a/algoperf/logger_utils.py +++ b/algoperf/logger_utils.py @@ -11,12 +11,12 @@ import sys from typing import Any, Dict, Optional -from absl import flags -from clu import metric_writers import GPUtil import pandas as pd import psutil import torch.distributed as dist +from absl import flags +from clu import metric_writers from algoperf import spec from algoperf.pytorch_utils import pytorch_setup @@ -37,12 +37,12 @@ def makedir(dir_name: str, exist_ok: bool = True) -> None: def get_log_dir( - experiment_dir: str, - workload: spec.Workload, - framework: str, - experiment_name: str, - resume_last_run: bool, - overwrite: bool, + experiment_dir: str, + workload: spec.Workload, + framework: str, + experiment_name: str, + resume_last_run: bool, + overwrite: bool, ) -> Optional[str]: # Construct path to experiment workload directory. experiment_dir = os.path.expanduser(experiment_dir) @@ -50,26 +50,29 @@ def get_log_dir( if experiment_name is None: experiment_path = os.path.join(experiment_dir, workload_dir_name) else: - experiment_path = os.path.join(experiment_dir, - experiment_name, - workload_dir_name) + experiment_path = os.path.join( + experiment_dir, experiment_name, workload_dir_name + ) if os.path.exists(experiment_path): if overwrite: logging.info( - f'Removing existing experiment directory {experiment_path} because ' - '--overwrite was set.') + f'Removing existing experiment directory {experiment_path} because ' + '--overwrite was set.' + ) if RANK == 0: shutil.rmtree(experiment_path) elif resume_last_run: logging.info( - f'Resuming from experiment directory {experiment_path} because ' - '--resume_last_run was set.') + f'Resuming from experiment directory {experiment_path} because ' + '--resume_last_run was set.' + ) else: if RANK == 0: resume = input( - 'Found existing experiment dir with the same name: {}. Do you wish ' - 'to resume training from this dir? [y/N]:'.format(experiment_path)) + 'Found existing experiment dir with the same name: {}. Do you wish ' + 'to resume training from this dir? [y/N]:'.format(experiment_path) + ) if resume.lower() != 'y': sys.exit() @@ -83,16 +86,18 @@ def get_log_dir( return experiment_path -def write_hparams(hparams: spec.Hyperparameters, - tuning_dir: str) -> spec.Hyperparameters: +def write_hparams( + hparams: spec.Hyperparameters, tuning_dir: str +) -> spec.Hyperparameters: hparams_file_name = os.path.join(tuning_dir, 'hparams.json') if os.path.exists(hparams_file_name): # If hparams.json already exist, use the previously saved hyperparameters. 
logging.info('Loading hparams from %s.', hparams_file_name) with open(hparams_file_name, 'r') as f: hparams_dict = json.load(f) - hparams = collections.namedtuple('Hyperparameters', - hparams_dict)(**hparams_dict) + hparams = collections.namedtuple('Hyperparameters', hparams_dict)( + **hparams_dict + ) else: logging.info('Saving hparams to %s.', hparams_file_name) if RANK == 0: @@ -108,8 +113,8 @@ def write_json(name: str, log_dict: Dict, indent: int = 2) -> None: def write_to_csv( - metrics: Dict, - csv_path: str, + metrics: Dict, + csv_path: str, ) -> None: try: with open(csv_path, 'r') as csv_file: @@ -118,8 +123,10 @@ def write_to_csv( except (pd.errors.EmptyDataError, FileNotFoundError) as e: measurements = pd.DataFrame([metrics], columns=sorted(metrics.keys())) if isinstance(e, pd.errors.EmptyDataError): - logging.info('Measurements file is empty. Create a new one, starting ' - 'with metrics from this step.') + logging.info( + 'Measurements file is empty. Create a new one, starting ' + 'with metrics from this step.' + ) with open(csv_path, 'w') as csv_file: measurements.to_csv(csv_file, index=False) return @@ -130,7 +137,8 @@ def _get_utilization() -> Dict: # CPU util_data['cpu.util.avg_percent_since_last'] = psutil.cpu_percent( - interval=None) # non-blocking (cpu util percentage since last call) + interval=None + ) # non-blocking (cpu util percentage since last call) util_data['cpu.freq.current'] = psutil.cpu_freq().current # Memory @@ -190,7 +198,7 @@ def _get_system_hardware_info() -> Dict: try: system_hardware_info['cpu_model_name'] = _get_cpu_model_name() system_hardware_info['cpu_count'] = psutil.cpu_count() - except: # pylint: disable=bare-except + except: # noqa: E722 logging.info('Unable to record cpu information. Continuing without it.') gpus = GPUtil.getGPUs() @@ -199,7 +207,7 @@ def _get_system_hardware_info() -> Dict: system_hardware_info['gpu_model_name'] = gpus[0].name system_hardware_info['gpu_count'] = len(gpus) system_hardware_info['gpu_driver'] = gpus[0].driver - except: # pylint: disable=bare-except + except: # noqa: E722 logging.info('Unable to record gpu information. Continuing without it.') return system_hardware_info @@ -208,11 +216,14 @@ def _get_system_hardware_info() -> Dict: def _get_system_software_info() -> Dict: system_software_info = {} - system_software_info['os_platform'] = \ - platform.platform() # Ex. 'Linux-5.4.48-x86_64-with-glibc2.29' - system_software_info['python_version'] = platform.python_version( + system_software_info['os_platform'] = ( + platform.platform() + ) # Ex. 'Linux-5.4.48-x86_64-with-glibc2.29' + system_software_info['python_version'] = ( + platform.python_version() ) # Ex. '3.11.10' - system_software_info['python_compiler'] = platform.python_compiler( + system_software_info['python_compiler'] = ( + platform.python_compiler() ) # Ex. 'GCC 9.3.0' # Note: do not store hostname as that may be sensitive @@ -221,26 +232,35 @@ def _get_system_software_info() -> Dict: system_software_info['git_commit_hash'] = _get_git_commit_hash() # Note: do not store git repo url as it may be sensitive or contain a # secret. - except: # pylint: disable=bare-except + except: # noqa: E722 logging.info('Unable to record git information. 
Continuing without it.') return system_software_info def _get_git_commit_hash() -> str: - return subprocess.check_output(['git', 'rev-parse', - 'HEAD']).decode('ascii').strip() + return ( + subprocess.check_output(['git', 'rev-parse', 'HEAD']) + .decode('ascii') + .strip() + ) def _get_git_branch() -> str: - return subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', - 'HEAD']).decode('ascii').strip() + return ( + subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + .decode('ascii') + .strip() + ) def _get_cpu_model_name() -> str: output = subprocess.check_output(['lscpu']).decode('ascii').strip() - return re.findall(r'(?=Model name:\s{1,}).*', - output)[0].split('Model name:')[1].strip() + return ( + re.findall(r'(?=Model name:\s{1,}).*', output)[0] + .split('Model name:')[1] + .strip() + ) def _is_primitive_type(item: Any) -> bool: @@ -252,23 +272,25 @@ def _get_workload_properties(workload: spec.Workload) -> Dict: workload_properties = {} skip_list = ['param_shapes', 'model_params_types'] keys = [ - key for key in dir(workload) - if not key.startswith('_') and key not in skip_list + key + for key in dir(workload) + if not key.startswith('_') and key not in skip_list ] for key in keys: try: attr = getattr(workload, key) - except: # pylint: disable=bare-except + except: # noqa: E722 logging.info( - f'Unable to record workload.{key} information. Continuing without it.' + f'Unable to record workload.{key} information. Continuing without it.' ) if _is_primitive_type(attr): workload_properties[f'workload.{key}'] = attr return workload_properties -def get_meta_data(workload: spec.Workload, - rng_seed: Optional[int] = None) -> Dict: +def get_meta_data( + workload: spec.Workload, rng_seed: Optional[int] = None +) -> Dict: meta_data = {} workload_properties = _get_workload_properties(workload) meta_data.update(workload_properties) @@ -290,12 +312,14 @@ class MetricLogger(object): the wrong time. 
""" - def __init__(self, - csv_path: str, - eval_csv_path: str, - events_dir: Optional[str] = None, - configs: Optional[flags.FLAGS] = None, - hyperparameters: Optional[spec.Hyperparameters] = None) -> None: + def __init__( + self, + csv_path: str, + eval_csv_path: str, + events_dir: Optional[str] = None, + configs: Optional[flags.FLAGS] = None, + hyperparameters: Optional[spec.Hyperparameters] = None, + ) -> None: self._measurements = {} self._csv_path = csv_path self._eval_csv_path = eval_csv_path @@ -305,15 +329,18 @@ def __init__(self, self._tb_metric_writer = metric_writers.create_default_writer(events_dir) if wandb is not None and self.use_wandb: wandb.init( - dir=events_dir, tags=[flags.FLAGS.workload, flags.FLAGS.framework]) + dir=events_dir, tags=[flags.FLAGS.workload, flags.FLAGS.framework] + ) wandb.config.update(configs) wandb.config.update(hyperparameters._asdict()) - def append_scalar_metrics(self, - metrics: Dict, - global_step: int, - preemption_count: Optional[int] = None, - is_eval: bool = False) -> None: + def append_scalar_metrics( + self, + metrics: Dict, + global_step: int, + preemption_count: Optional[int] = None, + is_eval: bool = False, + ) -> None: metrics['global_step'] = global_step if preemption_count is not None: metrics['preemption_count'] = preemption_count @@ -324,7 +351,8 @@ def append_scalar_metrics(self, if self._tb_metric_writer: self._tb_metric_writer.write_scalars( - step=int(metrics['global_step']), scalars=metrics) + step=int(metrics['global_step']), scalars=metrics + ) self._tb_metric_writer.flush() if wandb is not None and self.use_wandb: @@ -335,15 +363,16 @@ def finish(self) -> None: wandb.finish() -def set_up_loggers(train_dir: str, - configs: flags.FLAGS, - hyperparameters: spec.Hyperparameters) -> MetricLogger: +def set_up_loggers( + train_dir: str, configs: flags.FLAGS, hyperparameters: spec.Hyperparameters +) -> MetricLogger: csv_path = os.path.join(train_dir, 'measurements.csv') eval_csv_path = os.path.join(train_dir, 'eval_measurements.csv') metrics_logger = MetricLogger( - csv_path=csv_path, - eval_csv_path=eval_csv_path, - events_dir=train_dir, - configs=configs, - hyperparameters=hyperparameters) + csv_path=csv_path, + eval_csv_path=eval_csv_path, + events_dir=train_dir, + configs=configs, + hyperparameters=hyperparameters, + ) return metrics_logger diff --git a/algoperf/param_utils.py b/algoperf/param_utils.py index 05d882404..908ef0f27 100644 --- a/algoperf/param_utils.py +++ b/algoperf/param_utils.py @@ -14,7 +14,8 @@ def pytorch_param_shapes(model: nn.Module) -> Dict[str, spec.ShapeTuple]: def pytorch_param_types( - param_shapes: Dict[str, spec.ShapeTuple]) -> Dict[str, spec.ParameterType]: + param_shapes: Dict[str, spec.ShapeTuple], +) -> Dict[str, spec.ParameterType]: param_types = {} for name in param_shapes.keys(): if 'bn' in name: @@ -65,18 +66,21 @@ def pytorch_param_types( def jax_param_shapes( - params: spec.ParameterContainer) -> spec.ParameterShapeTree: + params: spec.ParameterContainer, +) -> spec.ParameterShapeTree: return jax.tree.map(lambda x: spec.ShapeTuple(x.shape), params) -def jax_param_types(param_shapes: spec.ParameterShapeTree, - parent_name: str = '') -> Dict[str, spec.ParameterType]: +def jax_param_types( + param_shapes: spec.ParameterShapeTree, parent_name: str = '' +) -> Dict[str, spec.ParameterType]: param_types = {} for name, value in param_shapes.items(): name = name.lower() if isinstance(value, dict) or isinstance(value, flax.core.FrozenDict): param_types[name] = jax_param_types( - value, 
parent_name=parent_name + '/' + name) + value, parent_name=parent_name + '/' + name + ) else: if 'batchnorm' in parent_name or 'bn' in parent_name: if name == 'scale': @@ -85,7 +89,8 @@ def jax_param_types(param_shapes: spec.ParameterShapeTree, param_types[name] = spec.ParameterType.BATCH_NORM_BIAS else: raise ValueError( - f'Unrecognized batch norm parameter: {parent_name}/{name}.') + f'Unrecognized batch norm parameter: {parent_name}/{name}.' + ) elif 'layernorm' in parent_name or 'ln' in parent_name: if name == 'scale': param_types[name] = spec.ParameterType.LAYER_NORM_SCALE @@ -93,7 +98,8 @@ def jax_param_types(param_shapes: spec.ParameterShapeTree, param_types[name] = spec.ParameterType.LAYER_NORM_BIAS else: raise ValueError( - f'Unrecognized layer norm parameter: {parent_name}/{name}.') + f'Unrecognized layer norm parameter: {parent_name}/{name}.' + ) elif 'conv' in parent_name: if 'bias' in name: param_types[name] = spec.ParameterType.BIAS @@ -102,8 +108,9 @@ def jax_param_types(param_shapes: spec.ParameterShapeTree, # Note that this is exact equality, not contained in, because # flax.linen.Embed names the embedding parameter "embedding" # https://github.com/google/flax/blob/main/flax/linen/linear.py#L604. - elif ('embedding' in name or - ('embedding' in parent_name and name == 'kernel')): + elif 'embedding' in name or ( + 'embedding' in parent_name and name == 'kernel' + ): param_types[name] = spec.ParameterType.EMBEDDING elif 'attention' in parent_name: if name == 'bias': @@ -122,7 +129,8 @@ def jax_param_types(param_shapes: spec.ParameterShapeTree, param_types[name] = spec.ParameterType.ATTENTION_QKV else: raise ValueError( - f'Unrecognized attention parameter: {parent_name}/{name}.') + f'Unrecognized attention parameter: {parent_name}/{name}.' + ) elif 'bias' in name: param_types[name] = spec.ParameterType.BIAS else: diff --git a/algoperf/profiler.py b/algoperf/profiler.py index fa2a1bee2..534a5ccfb 100644 --- a/algoperf/profiler.py +++ b/algoperf/profiler.py @@ -4,10 +4,10 @@ https://github.com/Lightning-AI/lightning/tree/master/src/pytorch_lightning/profilers. """ -from collections import defaultdict -from contextlib import contextmanager import os import time +from collections import defaultdict +from contextlib import contextmanager from typing import Dict, Generator, List, Optional, Tuple import numpy as np @@ -21,7 +21,6 @@ def _get_monotonic_time() -> float: class Profiler: - def __init__(self, local_rank: Optional[int] = None) -> None: self._local_rank = local_rank @@ -41,7 +40,8 @@ def start(self, action_name: str) -> None: pass if action_name in self.current_actions: raise ValueError( - f'Attempted to start {action_name} which has already started.') + f'Attempted to start {action_name} which has already started.' + ) self.current_actions[action_name] = _get_monotonic_time() def stop(self, action_name: str) -> None: @@ -49,8 +49,10 @@ def stop(self, action_name: str) -> None: pass end_time = _get_monotonic_time() if action_name not in self.current_actions: - raise ValueError(f'Attempting to stop recording an action ' - f'({action_name}) which was never started.') + raise ValueError( + f'Attempting to stop recording an action ' + f'({action_name}) which was never started.' 
+ ) start_time = self.current_actions.pop(action_name) duration = end_time - start_time self.recorded_durations[action_name].append(duration) @@ -64,16 +66,20 @@ def profile(self, action_name: str) -> Generator: self.stop(action_name) def _make_report( - self + self, ) -> Tuple[List[Tuple[str, float, float, int, float, float]], int, float]: total_duration = _get_monotonic_time() - self.start_time - report = [(str(a), - float(np.mean(d)), - float(np.std(d)), - len(d), - float(np.sum(d)), - 100.0 * float(np.sum(d)) / total_duration) for a, - d in self.recorded_durations.items()] + report = [ + ( + str(a), + float(np.mean(d)), + float(np.std(d)), + len(d), + float(np.sum(d)), + 100.0 * float(np.sum(d)) / total_duration, + ) + for a, d in self.recorded_durations.items() + ] report.sort(key=lambda x: x[5], reverse=True) total_calls = sum(x[3] for x in report) return report, total_calls, total_duration @@ -92,32 +98,42 @@ def log_row(action, mean, std, num_calls, total, per): row += f' {num_calls:<15}\t| {total:<15}\t| {per:<15}\t|' return row - header_string = log_row('Action', - 'Mean Duration (s)', - 'Std Duration (s)', - 'Num Calls', - 'Total Time (s)', - 'Percentage %') + header_string = log_row( + 'Action', + 'Mean Duration (s)', + 'Std Duration (s)', + 'Num Calls', + 'Total Time (s)', + 'Percentage %', + ) output_string_len = len(header_string.expandtabs()) sep_lines = f'{sep}{"-" * output_string_len}' output_string += sep_lines + header_string + sep_lines report, total_calls, total_duration = self._make_report() - output_string += log_row('Total', - '-----', - '-----', - f'{total_calls:}', - f'{total_duration:.5}', - '100 %') + output_string += log_row( + 'Total', + '-----', + '-----', + f'{total_calls:}', + f'{total_duration:.5}', + '100 %', + ) output_string += sep_lines - for action, mean_duration, std_duration, num_calls, \ - total_duration, duration_per in report: + for ( + action, + mean_duration, + std_duration, + num_calls, + total_duration, + duration_per, + ) in report: output_string += log_row( - action, - f'{mean_duration:.5}', - f'{std_duration:.5}', - f'{num_calls}', - f'{total_duration:.5}', - f'{duration_per:.5}', + action, + f'{mean_duration:.5}', + f'{std_duration:.5}', + f'{num_calls}', + f'{total_duration:.5}', + f'{duration_per:.5}', ) output_string += sep_lines output_string += sep @@ -125,7 +141,6 @@ def log_row(action, mean, std, num_calls, total, per): class PassThroughProfiler(Profiler): - def start(self, action_name: str) -> None: pass diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index 4a674985d..af09e67fc 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -1,18 +1,22 @@ import os from typing import Tuple -from absl import logging import jax import tensorflow as tf import torch import torch.distributed as dist +import torch.nn.functional as F +from absl import logging +from torch import Tensor, nn from algoperf import spec from algoperf.profiler import Profiler -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ - BatchNorm as ConformerBatchNorm -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ - BatchNorm as DeepspeechBatchNorm +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import ( + BatchNorm as ConformerBatchNorm, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import ( + BatchNorm as DeepspeechBatchNorm, +) def pytorch_setup() -> Tuple[bool, int, torch.device, int]: @@ -58,12 +62,13 @@ def 
sync_ddp_time(time: float, device: torch.device) -> float: return time_tensor.item() -def update_batch_norm_fn(module: spec.ParameterContainer, - update_batch_norm: bool) -> None: +def update_batch_norm_fn( + module: spec.ParameterContainer, update_batch_norm: bool +) -> None: bn_layers = ( - torch.nn.modules.batchnorm._BatchNorm, # PyTorch BN base class. - ConformerBatchNorm, # Custom BN class for conformer model. - DeepspeechBatchNorm, # Custom BN class for deepspeech model. + torch.nn.modules.batchnorm._BatchNorm, # PyTorch BN base class. + ConformerBatchNorm, # Custom BN class for conformer model. + DeepspeechBatchNorm, # Custom BN class for deepspeech model. ) if isinstance(module, bn_layers): if not update_batch_norm: @@ -77,3 +82,41 @@ def update_batch_norm_fn(module: spec.ParameterContainer, module.momentum = 0.0 elif hasattr(module, 'momentum_backup'): module.momentum = module.momentum_backup + + +class CustomDropout(nn.Module): + """A module around torch.nn.functional.dropout.""" + + def __init__(self): + super().__init__() + self._supports_custom_dropout = True + + def forward(self, x: Tensor, p: float) -> Tensor: + return F.dropout(x, p, training=self.training) + + +class CustomDropout2d(nn.Module): + """A module around torch.nn.functional.dropout2d.""" + + def __init__(self): + super().__init__() + self._supports_custom_dropout = True + + def forward(self, x: Tensor, p: float) -> Tensor: + return F.dropout2d(x, p, training=self.training) + + +class SequentialWithDropout(nn.Sequential): + """Sequential of modules with dropout.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._supports_custom_dropout = True + + def forward(self, x: Tensor, p: float) -> Tensor: + for module in self: + if getattr(module, '_supports_custom_dropout', False): + x = module(x, p) + else: + x = module(x) + return x diff --git a/algoperf/random_utils.py b/algoperf/random_utils.py index a579976ad..1dc773e80 100644 --- a/algoperf/random_utils.py +++ b/algoperf/random_utils.py @@ -2,16 +2,16 @@ from typing import Any, List, Union -from absl import flags -from absl import logging import numpy as np +from absl import flags, logging try: import jax.random as jax_rng except (ImportError, ModuleNotFoundError): logging.warning( - 'Could not import jax.random for the submission runner, falling back to ' - 'numpy random_utils.') + 'Could not import jax.random for the submission runner, falling back to ' + 'numpy random_utils.' + ) jax_rng = None FLAGS = flags.FLAGS @@ -54,8 +54,9 @@ def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name def _check_jax_install() -> None: if jax_rng is None: raise ValueError( - 'Must install jax to use the jax RNG library, or use PyTorch and pass ' - '--framework=pytorch to use the Numpy version instead.') + 'Must install jax to use the jax RNG library, or use PyTorch and pass ' + '--framework=pytorch to use the Numpy version instead.' 
+ ) def fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: diff --git a/algoperf/spec.py b/algoperf/spec.py index cf4f1a14e..5f7b930af 100644 --- a/algoperf/spec.py +++ b/algoperf/spec.py @@ -5,10 +5,10 @@ import functools from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union -from absl import logging import jax -from torch import nn import torch.nn.functional as F +from absl import logging +from torch import nn class LossType(enum.Enum): @@ -53,7 +53,6 @@ class ParameterType(enum.Enum): # Define this so that if using pytree iteration utilities, can iterate over the # model shapes pytree without iterating over the shape tuples. class ShapeTuple: - def __init__(self, shape_tuple): self.shape_tuple = shape_tuple @@ -64,19 +63,22 @@ def __eq__(self, other): return self.shape_tuple == other.shape_tuple -Shape = Union[Tuple[int], - Tuple[int, int], - Tuple[int, int, int], - Tuple[int, int, int, int], - ShapeTuple] +Shape = Union[ + Tuple[int], + Tuple[int, int], + Tuple[int, int, int], + Tuple[int, int, int, int], + ShapeTuple, +] ParameterShapeTree = Dict[str, Dict[str, Shape]] # If necessary, these can be zipped together easily given they have the same # structure, to get an iterator over pairs of leaves. ParameterKey = str # Dicts can be arbitrarily nested. -ParameterContainer = Union[Dict[ParameterKey, Dict[ParameterKey, Tensor]], - nn.Module] +ParameterContainer = Union[ + Dict[ParameterKey, Dict[ParameterKey, Tensor]], nn.Module +] ParameterTypeTree = Dict[ParameterKey, Dict[ParameterKey, ParameterType]] RandomState = Any # Union[jax.random.PRNGKey, int, bytes, ...] @@ -92,7 +94,6 @@ def __eq__(self, other): class Workload(metaclass=abc.ABCMeta): - def __init__(self, *args, **kwargs) -> None: del args del kwargs @@ -107,8 +108,9 @@ def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" @abc.abstractmethod - def has_reached_validation_target(self, eval_result: Dict[str, - float]) -> bool: + def has_reached_validation_target( + self, eval_result: Dict[str, float] + ) -> bool: """Return whether or not the workload validation goal has been reached.""" @abc.abstractmethod @@ -117,14 +119,15 @@ def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: @abc.abstractmethod def _build_input_queue( - self, - data_rng: RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, Any]]: + self, + data_rng: RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, Any]]: """Build the input queue for the workload data. This is the only function that is NOT allowed to be called by submitters. @@ -213,8 +216,9 @@ def param_shapes(self): """The shapes of the parameters in the workload model.""" if self._param_shapes is None: raise ValueError( - 'This should not happen, workload.init_model_fn() should be called ' - 'before workload.param_shapes!') + 'This should not happen, workload.init_model_fn() should be called ' + 'before workload.param_shapes!' 
+ ) return self._param_shapes @property @@ -222,8 +226,9 @@ def model_params_types(self): """The types of the parameters in the workload model.""" if self._param_types is None: raise ValueError( - 'This should not happen, workload.init_model_fn() should be called ' - 'before workload.param_types!') + 'This should not happen, workload.init_model_fn() should be called ' + 'before workload.param_types!' + ) return self._param_types @abc.abstractmethod @@ -234,10 +239,12 @@ def is_output_params(self, param_key: ParameterKey) -> bool: # Tuple[RandomState, Optional[float], Optional[float]], # ParameterContainer] @abc.abstractmethod - def init_model_fn(self, - rng: RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> ModelInitState: + def init_model_fn( + self, + rng: RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None, + ) -> ModelInitState: """Return (initial_params, initial_model_state).""" # ModelFn = Callable[ @@ -247,32 +254,39 @@ def init_model_fn(self, # ModelAuxiliaryState, # ForwardPassMode, # RandomState, - # bool], + # bool, + # float], # Tensor] @abc.abstractmethod - def model_fn(self, - params: ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, Tensor], - model_state: ModelAuxiliaryState, - mode: ForwardPassMode, - rng: RandomState, - update_batch_norm: bool) -> Tuple[Tensor, ModelAuxiliaryState]: + def model_fn( + self, + params: ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, Tensor], + model_state: ModelAuxiliaryState, + mode: ForwardPassMode, + rng: RandomState, + update_batch_norm: bool, + dropout_rate: float, + ) -> Tuple[Tensor, ModelAuxiliaryState]: """Return logits_batch""" # Possible side effect of updating BN. - def output_activation_fn(self, logits_batch: Tensor, - framework: str) -> Tensor: + def output_activation_fn( + self, logits_batch: Tensor, framework: str + ) -> Tensor: """Turn logits into probabilities, according to the loss_type property.""" if framework not in ['pytorch', 'jax']: raise ValueError( - f'`framework` has to be either `pytorch` or `jax`, got {framework}.') + f'`framework` has to be either `pytorch` or `jax`, got {framework}.' + ) activation_fn = { - LossType.MEAN_SQUARED_ERROR: lambda z: z, - LossType.MEAN_ABSOLUTE_ERROR: lambda z: z, + LossType.MEAN_SQUARED_ERROR: lambda z: z, + LossType.MEAN_ABSOLUTE_ERROR: lambda z: z, } is_pytorch = framework == 'pytorch' # If False, framework == 'jax'. softmax_fn = ( - functools.partial(F.softmax, dim=-1) if is_pytorch else jax.nn.softmax) + functools.partial(F.softmax, dim=-1) if is_pytorch else jax.nn.softmax + ) sigmoid_fn = F.sigmoid if is_pytorch else jax.nn.sigmoid activation_fn[LossType.SOFTMAX_CROSS_ENTROPY] = softmax_fn activation_fn[LossType.SIGMOID_CROSS_ENTROPY] = sigmoid_fn @@ -284,12 +298,13 @@ def output_activation_fn(self, logits_batch: Tensor, # `update_params`. @abc.abstractmethod def loss_fn( - self, - # Dense or one-hot labels, or a tuple of (tensor, padding) for speech. - label_batch: Union[Tuple[Tensor, Tensor], Tensor], - logits_batch: Union[Tuple[Tensor, Tensor], Tensor], - mask_batch: Optional[Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, Tensor]: # differentiable + self, + # Dense or one-hot labels, or a tuple of (tensor, padding) for speech. 
+ label_batch: Union[Tuple[Tensor, Tensor], Tensor], + logits_batch: Union[Tuple[Tensor, Tensor], Tensor], + mask_batch: Optional[Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -298,48 +313,54 @@ def loss_fn( """ @abc.abstractmethod - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: ParameterContainer, - model_state: ModelAuxiliaryState, - rng: RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: ParameterContainer, + model_state: ModelAuxiliaryState, + rng: RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Evaluate the model on a given dataset split, return final scalars.""" - def eval_model(self, - global_batch_size: int, - params: ParameterContainer, - model_state: ModelAuxiliaryState, - rng: RandomState, - data_dir: str, - imagenet_v2_data_dir: Optional[str], - global_step: int) -> Dict[str, float]: + def eval_model( + self, + global_batch_size: int, + params: ParameterContainer, + model_state: ModelAuxiliaryState, + rng: RandomState, + data_dir: str, + imagenet_v2_data_dir: Optional[str], + global_step: int, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" logging.info('Evaluating on the training split.') train_metrics = self._eval_model_on_split( - split='eval_train', - num_examples=self.num_eval_train_examples, - global_batch_size=global_batch_size, - params=params, - model_state=model_state, - rng=rng, - data_dir=data_dir, - global_step=global_step) + split='eval_train', + num_examples=self.num_eval_train_examples, + global_batch_size=global_batch_size, + params=params, + model_state=model_state, + rng=rng, + data_dir=data_dir, + global_step=global_step, + ) eval_metrics = {'train/' + k: v for k, v in train_metrics.items()} # We always require a validation set. 
logging.info('Evaluating on the validation split.') validation_metrics = self._eval_model_on_split( - 'validation', - num_examples=self.num_validation_examples, - global_batch_size=global_batch_size, - params=params, - model_state=model_state, - rng=rng, - data_dir=data_dir, - global_step=global_step) + 'validation', + num_examples=self.num_validation_examples, + global_batch_size=global_batch_size, + params=params, + model_state=model_state, + rng=rng, + data_dir=data_dir, + global_step=global_step, + ) for k, v in validation_metrics.items(): eval_metrics['validation/' + k] = v eval_metrics['validation/num_examples'] = self.num_validation_examples @@ -348,14 +369,15 @@ def eval_model(self, if self.num_test_examples is not None: logging.info('Evaluating on the test split.') test_metrics = self._eval_model_on_split( - 'test', - num_examples=self.num_test_examples, - global_batch_size=global_batch_size, - params=params, - model_state=model_state, - rng=rng, - data_dir=imagenet_v2_data_dir if imagenet_v2_data_dir else data_dir, - global_step=global_step) + 'test', + num_examples=self.num_test_examples, + global_batch_size=global_batch_size, + params=params, + model_state=model_state, + rng=rng, + data_dir=imagenet_v2_data_dir if imagenet_v2_data_dir else data_dir, + global_step=global_step, + ) for k, v in test_metrics.items(): eval_metrics['test/' + k] = v eval_metrics['test/num_examples'] = self.num_test_examples @@ -372,27 +394,32 @@ class TrainingCompleteError(Exception): # Training algorithm track submission functions, to be filled in by the # submitter. -InitOptimizerFn = Callable[[ +InitOptimizerFn = Callable[ + [ Workload, ParameterContainer, ModelAuxiliaryState, Hyperparameters, - RandomState -], - OptimizerState] - - -def init_optimizer_state(workload: Workload, - model_params: ParameterContainer, - model_state: ModelAuxiliaryState, - hyperparameters: Hyperparameters, - rng: RandomState) -> OptimizerState: + RandomState, + ], + OptimizerState, +] + + +def init_optimizer_state( + workload: Workload, + model_params: ParameterContainer, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + rng: RandomState, +) -> OptimizerState: # return initial_optimizer_state pass UpdateReturn = Tuple[OptimizerState, ParameterContainer, ModelAuxiliaryState] -UpdateParamsFn = Callable[[ +UpdateParamsFn = Callable[ + [ Workload, ParameterContainer, ParameterTypeTree, @@ -404,9 +431,10 @@ def init_optimizer_state(workload: Workload, List[Tuple[int, float]], int, RandomState, - Optional[Dict[str, Any]] -], - UpdateReturn] + Optional[Dict[str, Any]], + ], + UpdateReturn, +] # Each call to this function is considered a "step". @@ -415,23 +443,26 @@ def init_optimizer_state(workload: Workload, # and if has not actually achieved the goal then it will be considered as not # achieved the goal and get an infinite time score. Most submissions will likely # wait until the next free eval and not use this functionality. 
-def update_params(workload: Workload, - current_param_container: ParameterContainer, - current_params_types: ParameterTypeTree, - model_state: ModelAuxiliaryState, - hyperparameters: Hyperparameters, - batch: Dict[str, Tensor], - loss_type: LossType, - optimizer_state: OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: RandomState, - train_state: Optional[Dict[str, Any]] = None) -> UpdateReturn: +def update_params( + workload: Workload, + current_param_container: ParameterContainer, + current_params_types: ParameterTypeTree, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + batch: Dict[str, Tensor], + loss_type: LossType, + optimizer_state: OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" pass -PrepareForEvalFn = Callable[[ +PrepareForEvalFn = Callable[ + [ Workload, ParameterContainer, ParameterTypeTree, @@ -441,27 +472,31 @@ def update_params(workload: Workload, OptimizerState, List[Tuple[int, float]], int, - RandomState -], - UpdateReturn] + RandomState, + ], + UpdateReturn, +] # Prepare model and optimizer for evaluation. -def prepare_for_eval(workload: Workload, - current_param_container: ParameterContainer, - current_params_types: ParameterTypeTree, - model_state: ModelAuxiliaryState, - hyperparameters: Hyperparameters, - loss_type: LossType, - optimizer_state: OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: RandomState) -> UpdateReturn: +def prepare_for_eval( + workload: Workload, + current_param_container: ParameterContainer, + current_params_types: ParameterTypeTree, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + loss_type: LossType, + optimizer_state: OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: RandomState, +) -> UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" pass -DataSelectionFn = Callable[[ +DataSelectionFn = Callable[ + [ Workload, Iterator[Dict[str, Any]], OptimizerState, @@ -469,21 +504,24 @@ def prepare_for_eval(workload: Workload, LossType, Hyperparameters, int, - RandomState -], - Tuple[Tensor, Tensor]] + RandomState, + ], + Tuple[Tensor, Tensor], +] # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. -def data_selection(workload: Workload, - input_queue: Iterator[Dict[str, Any]], - optimizer_state: OptimizerState, - current_param_container: ParameterContainer, - model_state: ModelAuxiliaryState, - hyperparameters: Hyperparameters, - global_step: int, - rng: RandomState) -> Dict[str, Tensor]: +def data_selection( + workload: Workload, + input_queue: Iterator[Dict[str, Any]], + optimizer_state: OptimizerState, + current_param_container: ParameterContainer, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + global_step: int, + rng: RandomState, +) -> Dict[str, Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. 
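[Editorial note, not part of the patch: with the reformatted `spec.py` above, `Workload.model_fn` now takes a `dropout_rate: float` argument alongside `update_batch_norm`, so dropout is chosen per forward pass instead of being fixed in `init_model_fn`. Below is a minimal, illustrative sketch of how a submission-side helper could thread that argument through to the loss; the helper name `forward_and_loss` and the `0.1` rate are assumptions for illustration only, not code from this repository.]

from typing import Dict

from algoperf import spec


def forward_and_loss(
  workload: spec.Workload,
  params: spec.ParameterContainer,
  batch: Dict[str, spec.Tensor],
  model_state: spec.ModelAuxiliaryState,
  rng: spec.RandomState,
  dropout_rate: float = 0.1,  # illustrative value, chosen by the submission
) -> spec.Tensor:
  """Train-mode forward pass using the per-call dropout_rate argument."""
  logits, _ = workload.model_fn(
    params,
    batch,
    model_state,
    spec.ForwardPassMode.TRAIN,
    rng,
    update_batch_norm=True,
    dropout_rate=dropout_rate,
  )
  # loss_fn returns {'summed', 'n_valid_examples', 'per_example'}; average
  # the summed loss over the valid examples to get a scalar training loss.
  loss_dict = workload.loss_fn(batch['targets'], logits)
  return loss_dict['summed'] / loss_dict['n_valid_examples']

[End of editorial note.]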
diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py index 728d05f29..7fbc95bc6 100644 --- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py @@ -8,22 +8,24 @@ import functools from typing import Dict, Iterator, Tuple -from flax import jax_utils import jax import tensorflow as tf import tensorflow_datasets as tfds +from flax import jax_utils from algoperf import spec from algoperf.data_utils import shard_and_maybe_pad_np -def preprocess_for_train(image: spec.Tensor, - rng: spec.RandomState, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float], - crop_size: int, - padding_size: int, - dtype: tf.DType = tf.float32) -> spec.Tensor: +def preprocess_for_train( + image: spec.Tensor, + rng: spec.RandomState, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + crop_size: int, + padding_size: int, + dtype: tf.DType = tf.float32, +) -> spec.Tensor: """Preprocesses the given image for training. Args: @@ -44,20 +46,23 @@ def preprocess_for_train(image: spec.Tensor, flip_rng = rng[1, :] image_shape = tf.shape(image) - image = tf.image.resize_with_crop_or_pad(image, - image_shape[0] + padding_size, - image_shape[1] + padding_size) + image = tf.image.resize_with_crop_or_pad( + image, image_shape[0] + padding_size, image_shape[1] + padding_size + ) image = tf.image.stateless_random_crop( - image, (crop_size, crop_size, 3), seed=crop_rng) + image, (crop_size, crop_size, 3), seed=crop_rng + ) image = tf.image.stateless_random_flip_left_right(image, seed=flip_rng) image = normalize_image(image, mean_rgb, stddev_rgb, dtype=dtype) return image -def preprocess_for_eval(image: spec.Tensor, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float], - dtype: tf.DType = tf.float32) -> spec.Tensor: +def preprocess_for_eval( + image: spec.Tensor, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + dtype: tf.DType = tf.float32, +) -> spec.Tensor: """Preprocesses the given image for evaluation. 
Args: @@ -74,10 +79,12 @@ def preprocess_for_eval(image: spec.Tensor, return image -def normalize_image(image: spec.Tensor, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float], - dtype=tf.float32) -> spec.Tensor: +def normalize_image( + image: spec.Tensor, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + dtype=tf.float32, +) -> spec.Tensor: image = tf.image.convert_image_dtype(image, dtype) image -= tf.constant(mean_rgb, shape=[1, 1, 3], dtype=image.dtype) image /= tf.constant(stddev_rgb, shape=[1, 1, 3], dtype=image.dtype) @@ -85,17 +92,17 @@ def normalize_image(image: spec.Tensor, def create_split( - split: str, - dataset_builder: tfds.core.dataset_builder.DatasetBuilder, - rng: spec.RandomState, - global_batch_size: int, - train: bool, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float], - cache: bool = False, - repeat_final_dataset: bool = False, - crop_size: int = 32, - padding_size: int = 4, + split: str, + dataset_builder: tfds.core.dataset_builder.DatasetBuilder, + rng: spec.RandomState, + global_batch_size: int, + train: bool, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + cache: bool = False, + repeat_final_dataset: bool = False, + crop_size: int = 32, + padding_size: int = 4, ) -> Iterator[Dict[str, spec.Tensor]]: """Creates a split from the CIFAR-10 dataset using TensorFlow Datasets.""" shuffle_rng, preprocess_rng = jax.random.split(rng, 2) @@ -104,14 +111,17 @@ def preprocess_example(example_index, example): dtype = tf.float32 if train: per_step_preprocess_rng = tf.random.experimental.stateless_fold_in( - tf.cast(preprocess_rng, tf.int64), example_index) - image = preprocess_for_train(example['image'], - per_step_preprocess_rng, - mean_rgb, - stddev_rgb, - crop_size, - padding_size, - dtype) + tf.cast(preprocess_rng, tf.int64), example_index + ) + image = preprocess_for_train( + example['image'], + per_step_preprocess_rng, + mean_rgb, + stddev_rgb, + crop_size, + padding_size, + dtype, + ) else: image = preprocess_for_eval(example['image'], mean_rgb, stddev_rgb, dtype) return {'inputs': image, 'targets': example['label']} @@ -132,7 +142,8 @@ def preprocess_example(example_index, example): # index that we can fold into the RNG seed. 
ds = ds.enumerate() ds = ds.map( - preprocess_example, num_parallel_calls=tf.data.experimental.AUTOTUNE) + preprocess_example, num_parallel_calls=tf.data.experimental.AUTOTUNE + ) ds = ds.batch(global_batch_size, drop_remainder=train) if repeat_final_dataset: @@ -144,32 +155,36 @@ def preprocess_example(example_index, example): def create_input_iter( - split: str, - dataset_builder: tfds.core.dataset_builder.DatasetBuilder, - rng: spec.RandomState, - global_batch_size: int, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float], - crop_size: int, - padding_size: int, - train: bool, - cache: bool, - repeat_final_dataset: bool) -> Iterator[Dict[str, spec.Tensor]]: + split: str, + dataset_builder: tfds.core.dataset_builder.DatasetBuilder, + rng: spec.RandomState, + global_batch_size: int, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + crop_size: int, + padding_size: int, + train: bool, + cache: bool, + repeat_final_dataset: bool, +) -> Iterator[Dict[str, spec.Tensor]]: ds = create_split( - split, - dataset_builder, - rng, - global_batch_size, - train=train, - mean_rgb=mean_rgb, - stddev_rgb=stddev_rgb, - cache=cache, - repeat_final_dataset=repeat_final_dataset, - crop_size=crop_size, - padding_size=padding_size) + split, + dataset_builder, + rng, + global_batch_size, + train=train, + mean_rgb=mean_rgb, + stddev_rgb=stddev_rgb, + cache=cache, + repeat_final_dataset=repeat_final_dataset, + crop_size=crop_size, + padding_size=padding_size, + ) it = map( - functools.partial( - shard_and_maybe_pad_np, global_batch_size=global_batch_size), - ds) + functools.partial( + shard_and_maybe_pad_np, global_batch_size=global_batch_size + ), + ds, + ) it = jax_utils.prefetch_to_device(it, 2) return it diff --git a/algoperf/workloads/cifar/cifar_jax/models.py b/algoperf/workloads/cifar/cifar_jax/models.py index 957079272..95238c997 100644 --- a/algoperf/workloads/cifar/cifar_jax/models.py +++ b/algoperf/workloads/cifar/cifar_jax/models.py @@ -7,8 +7,8 @@ import functools from typing import Any, Callable, Tuple -from flax import linen as nn import jax.numpy as jnp +from flax import linen as nn from algoperf import spec from algoperf.workloads.imagenet_resnet.imagenet_jax.models import ResNetBlock @@ -25,48 +25,52 @@ class ResNet(nn.Module): act: Callable = nn.relu @nn.compact - def __call__(self, - x: spec.Tensor, - update_batch_norm: bool = True, - use_running_average_bn: bool = None) -> spec.Tensor: + def __call__( + self, + x: spec.Tensor, + update_batch_norm: bool = True, + use_running_average_bn: bool = None, + ) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) # Preserve default behavior for backwards compatibility if use_running_average_bn is None: use_running_average_bn = not update_batch_norm norm = functools.partial( - nn.BatchNorm, - use_running_average=use_running_average_bn, - momentum=0.9, - epsilon=1e-5, - dtype=self.dtype) + nn.BatchNorm, + use_running_average=use_running_average_bn, + momentum=0.9, + epsilon=1e-5, + dtype=self.dtype, + ) x = conv( - self.num_filters, (3, 3), (1, 1), - padding=[(1, 1), (1, 1)], - name='Conv_init')( - x) + self.num_filters, + (3, 3), + (1, 1), + padding=[(1, 1), (1, 1)], + name='Conv_init', + )(x) x = norm(name='BatchNorm_init')(x) x = nn.relu(x) for i, block_size in enumerate(self.stage_sizes): for j in range(block_size): strides = (2, 2) if i > 0 and j == 0 else (1, 1) x = self.block_cls( - self.num_filters * 2**i, - strides=strides, - conv=conv, - norm=norm, - 
act=self.act)( - x) + self.num_filters * 2**i, + strides=strides, + conv=conv, + norm=norm, + act=self.act, + )(x) x = nn.avg_pool(x, (4, 4), strides=(4, 4)) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense( - self.num_classes, - kernel_init=nn.initializers.normal(), - dtype=self.dtype)( - x) + self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype + )(x) return x ResNet18 = functools.partial( - ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock) + ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock +) diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index ad43bc62f..bc26e3899 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -3,32 +3,30 @@ import functools from typing import Any, Dict, Iterator, Optional, Tuple -from flax import jax_utils -from flax import linen as nn -from flax.core import pop import jax -from jax import lax import jax.numpy as jnp import optax import tensorflow_datasets as tfds +from flax import jax_utils +from flax import linen as nn +from flax.core import pop +from jax import lax -from algoperf import param_utils -from algoperf import spec +from algoperf import param_utils, spec from algoperf.workloads.cifar.cifar_jax import models from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter from algoperf.workloads.cifar.workload import BaseCifarWorkload class CifarWorkload(BaseCifarWorkload): - def _build_cifar_dataset( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, ) -> Iterator[Dict[str, spec.Tensor]]: ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir) train = split == 'train' @@ -38,38 +36,38 @@ def _build_cifar_dataset( elif split == 'validation': split = f'train[{self.num_train_examples}:]' ds = create_input_iter( - split, - ds_builder, - data_rng, - batch_size, - self.train_mean, - self.train_stddev, - self.crop_size, - self.padding_size, - train=train, - cache=not train if cache is None else cache, - repeat_final_dataset=repeat_final_dataset) + split, + ds_builder, + data_rng, + batch_size, + self.train_mean, + self.train_stddev, + self.crop_size, + self.padding_size, + train=train, + cache=not train if cache is None else cache, + repeat_final_dataset=repeat_final_dataset, + ) return ds def _build_input_queue( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del num_batches - return self._build_cifar_dataset(data_rng, - split, - data_dir, - global_batch_size, - cache, - repeat_final_dataset) + return self._build_cifar_dataset( + data_rng, split, data_dir, global_batch_size, cache, repeat_final_dataset + ) def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: + self, model_state: spec.ModelAuxiliaryState + ) -> 
spec.ModelAuxiliaryState: """Sync the batch statistics across replicas.""" # An axis_name is passed to pmap which can then be used by pmean. # In this case each device has its own version of the batch statistics @@ -79,20 +77,15 @@ def sync_batch_stats( new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: """Dropout is unused.""" - del dropout_rate - del aux_dropout_rate model_cls = getattr(models, 'ResNet18') model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) self._model = model input_shape = (1, 32, 32, 3) - variables = jax.jit(model.init)({'params': rng}, - jnp.ones(input_shape, model.dtype)) + variables = jax.jit(model.init)( + {'params': rng}, jnp.ones(input_shape, model.dtype) + ) model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -104,43 +97,46 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} if update_batch_norm: logits, new_model_state = self._model.apply( - variables, - augmented_and_preprocessed_input_batch['inputs'], - update_batch_norm=update_batch_norm, - mutable=['batch_stats'], - use_running_average_bn=use_running_average_bn) + variables, + augmented_and_preprocessed_input_batch['inputs'], + update_batch_norm=update_batch_norm, + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn, + ) return logits, new_model_state else: logits = self._model.apply( - variables, - augmented_and_preprocessed_input_batch['inputs'], - update_batch_norm=update_batch_norm, - mutable=False, - use_running_average_bn=use_running_average_bn) + variables, + augmented_and_preprocessed_input_batch['inputs'], + update_batch_norm=update_batch_norm, + mutable=False, + use_running_average_bn=use_running_average_bn, + ) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). 
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -150,7 +146,8 @@ def loss_fn( one_hot_targets = jax.nn.one_hot(label_batch, self._num_classes) smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing) per_example_losses = -jnp.sum( - smoothed_targets * nn.log_softmax(logits_batch), axis=-1) + smoothed_targets * nn.log_softmax(logits_batch), axis=-1 + ) # `mask_batch` is assumed to be shape [batch]. if mask_batch is not None: per_example_losses *= mask_batch @@ -159,51 +156,53 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } - def _compute_metrics(self, - logits: spec.Tensor, - labels: spec.Tensor, - weights: spec.Tensor) -> Dict[str, spec.Tensor]: + def _compute_metrics( + self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor + ) -> Dict[str, spec.Tensor]: summed_loss = self.loss_fn(labels, logits, weights)['summed'] # Number of correct predictions. accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights) metrics = { - 'loss': summed_loss, - 'accuracy': accuracy, + 'loss': summed_loss, + 'accuracy': accuracy, } metrics = lax.psum(metrics, axis_name='batch') return metrics @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.pmap, + axis_name='batch', + in_axes=(None, 0, 0, 0, None), + static_broadcasted_argnums=(0,), + ) def _eval_model( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: """Return the mean accuracy and loss as a dict.""" logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) weights = batch.get('weights') if weights is None: weights = jnp.ones(len(logits)) return self._compute_metrics(logits, batch['targets'], weights) def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) diff --git a/algoperf/workloads/cifar/cifar_pytorch/models.py b/algoperf/workloads/cifar/cifar_pytorch/models.py index e6a7a8a81..0e08f5c5a 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/models.py +++ b/algoperf/workloads/cifar/cifar_pytorch/models.py @@ -12,23 +12,24 @@ from algoperf import spec from algoperf.init_utils import pytorch_default_init -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ - BasicBlock -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ - Bottleneck -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import conv1x1 +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import ( + BasicBlock, + Bottleneck, + conv1x1, +) class ResNet(nn.Module): - - def __init__(self, - block: 
Type[Union[BasicBlock, Bottleneck]], - layers: List[int], - num_classes: int = 10, - groups: int = 1, - width_per_group: int = 64, - replace_stride_with_dilation: Optional[List[bool]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None) -> None: + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 10, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: super().__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -42,21 +43,26 @@ def __init__(self, replace_stride_with_dilation = [False, False, False] if len(replace_stride_with_dilation) != 3: raise ValueError( - 'replace_stride_with_dilation should be None ' - f'or a 3-element tuple, got {replace_stride_with_dilation}') + 'replace_stride_with_dilation should be None ' + f'or a 3-element tuple, got {replace_stride_with_dilation}' + ) self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( - 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False + ) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer( - block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) + block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] + ) self.layer3 = self._make_layer( - block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) + block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] + ) self.layer4 = self._make_layer( - block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) + block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] + ) self.fc = nn.Linear(512 * block.expansion, num_classes) self.reset_parameters() @@ -68,7 +74,7 @@ def reset_parameters(self) -> None: nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) nn.init.normal_(self.fc.weight, std=1e-2) - nn.init.constant_(self.fc.bias, 0.) 
+ nn.init.constant_(self.fc.bias, 0.0) # Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, @@ -81,12 +87,14 @@ def reset_parameters(self) -> None: elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) - def _make_layer(self, - block: Type[Union[BasicBlock, Bottleneck]], - planes: int, - blocks: int, - stride: int = 1, - dilate: bool = False) -> nn.Sequential: + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + ) -> nn.Sequential: norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation @@ -95,32 +103,39 @@ def _make_layer(self, stride = 1 if stride != 1 or self.inplanes != planes * block.expansion: downsample = torch.nn.Sequential( - collections.OrderedDict([ - ("conv", conv1x1(self.inplanes, planes * block.expansion, - stride)), - ("bn", norm_layer(planes * block.expansion)), - ])) + collections.OrderedDict( + [ + ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)), + ('bn', norm_layer(planes * block.expansion)), + ] + ) + ) layers = [] layers.append( - block(self.inplanes, - planes, - stride, - downsample, - self.groups, - self.base_width, - previous_dilation, - norm_layer)) + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + ) + ) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append( - block( - self.inplanes, - planes, - groups=self.groups, - base_width=self.base_width, - dilation=self.dilation, - norm_layer=norm_layer)) + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) return nn.Sequential(*layers) diff --git a/algoperf/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py index d05131c27..f1189bebc 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/workload.py +++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py @@ -12,10 +12,7 @@ from torchvision import transforms from torchvision.datasets import CIFAR10 -from algoperf import data_utils -from algoperf import param_utils -from algoperf import pytorch_utils -from algoperf import spec +from algoperf import data_utils, param_utils, pytorch_utils, spec from algoperf.workloads.cifar.cifar_pytorch.models import resnet18 from algoperf.workloads.cifar.workload import BaseCifarWorkload @@ -23,7 +20,6 @@ class CifarWorkload(BaseCifarWorkload): - def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # Is set in submission_runner.py for workloads with PyTorch evaluation @@ -34,7 +30,8 @@ def __init__(self, *args, **kwargs) -> None: def eval_num_workers(self) -> int: if self._eval_num_workers is None: raise ValueError( - 'eval_num_workers property must be set before workload is used.') + 'eval_num_workers property must be set before workload is used.' 
+ ) return self._eval_num_workers @eval_num_workers.setter @@ -42,47 +39,54 @@ def eval_num_workers(self, eval_num_workers: int): self._eval_num_workers = eval_num_workers def _build_dataset( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, ) -> torch.utils.data.DataLoader: del cache del repeat_final_dataset is_train = split == 'train' - normalize = transforms.Compose([ + normalize = transforms.Compose( + [ transforms.ToTensor(), transforms.Normalize(mean=self.train_mean, std=self.train_stddev), - ]) + ] + ) eval_transform_config = normalize - train_transform_config = transforms.Compose([ + train_transform_config = transforms.Compose( + [ transforms.RandomCrop( - size=self.crop_size, - padding=self.padding_size, + size=self.crop_size, + padding=self.padding_size, ), transforms.RandomHorizontalFlip(), normalize, - ]) + ] + ) transform = train_transform_config if is_train else eval_transform_config dataset = CIFAR10( - root=data_dir, - train=split in ['train', 'eval_train', 'validation'], - download=False, - transform=transform) + root=data_dir, + train=split in ['train', 'eval_train', 'validation'], + download=False, + transform=transform, + ) assert self.num_train_examples + self.num_validation_examples == 50000 indices = list(range(50000)) indices_split = { - 'train': indices[:self.num_train_examples], - 'validation': indices[self.num_train_examples:], + 'train': indices[: self.num_train_examples], + 'validation': indices[self.num_train_examples :], } if split == 'eval_train': train_indices = indices_split['train'] random.Random(int(data_rng[0])).shuffle(train_indices) - indices_split['eval_train'] = train_indices[:self.num_eval_train_examples] + indices_split['eval_train'] = train_indices[ + : self.num_eval_train_examples + ] if split in indices_split: dataset = torch.utils.data.Subset(dataset, indices_split[split]) @@ -92,30 +96,34 @@ def _build_dataset( ds_iter_batch_size = per_device_batch_size if is_train: sampler = torch.utils.data.distributed.DistributedSampler( - dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True) + dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True + ) else: sampler = data_utils.DistributedEvalSampler( - dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False) + dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False + ) else: ds_iter_batch_size = global_batch_size dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=ds_iter_batch_size, - shuffle=not USE_PYTORCH_DDP and is_train, - sampler=sampler, - num_workers=4 if is_train else self.eval_num_workers, - pin_memory=True, - drop_last=is_train) + dataset, + batch_size=ds_iter_batch_size, + shuffle=not USE_PYTORCH_DDP and is_train, + sampler=sampler, + num_workers=4 if is_train else self.eval_num_workers, + pin_memory=True, + drop_last=is_train, + ) dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE) dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP) return dataloader def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: 
"""Dropout is unused.""" del dropout_rate del aux_dropout_rate @@ -143,30 +151,34 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['fc.weight', 'fc.bias'] def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng model = params if mode == spec.ForwardPassMode.EVAL: if update_batch_norm: raise ValueError( - 'Batch norm statistics cannot be updated during evaluation.') + 'Batch norm statistics cannot be updated during evaluation.' + ) model.eval() if mode == spec.ForwardPassMode.TRAIN: model.train() model.apply( - functools.partial( - pytorch_utils.update_batch_norm_fn, - update_batch_norm=update_batch_norm)) + functools.partial( + pytorch_utils.update_batch_norm_fn, + update_batch_norm=update_batch_norm, + ) + ) contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) @@ -175,11 +187,12 @@ def model_fn( # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -187,10 +200,11 @@ def loss_fn( (not synced across devices). """ per_example_losses = F.cross_entropy( - logits_batch, - label_batch, - reduction='none', - label_smoothing=label_smoothing) + logits_batch, + label_batch, + reduction='none', + label_smoothing=label_smoothing, + ) # `mask_batch` is assumed to be shape [batch]. 
if mask_batch is not None: per_example_losses *= mask_batch @@ -199,25 +213,27 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), + 'per_example': per_example_losses, } def _eval_model( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: """Return the mean accuracy and loss as a dict.""" logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) targets = batch['targets'] weights = batch.get('weights') if weights is None: @@ -229,8 +245,8 @@ def _eval_model( return {'accuracy': accuracy, 'loss': summed_loss} def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algoperf/workloads/cifar/workload.py b/algoperf/workloads/cifar/workload.py index c0d565108..31636807c 100644 --- a/algoperf/workloads/cifar/workload.py +++ b/algoperf/workloads/cifar/workload.py @@ -7,15 +7,14 @@ import jax import torch +import algoperf.random_utils as prng from algoperf import spec from algoperf.pytorch_utils import pytorch_setup -import algoperf.random_utils as prng USE_PYTORCH_DDP, _, _, _ = pytorch_setup() class BaseCifarWorkload(spec.Workload): - _num_classes: int = 10 @property @@ -23,8 +22,9 @@ def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" return 'accuracy' - def has_reached_validation_target(self, eval_result: Dict[str, - float]) -> bool: + def has_reached_validation_target( + self, eval_result: Dict[str, float] + ) -> bool: return eval_result['validation/accuracy'] > self.validation_target_value @property @@ -51,8 +51,9 @@ def num_eval_train_examples(self) -> int: # Round up from num_validation_examples (which is the default for # num_eval_train_examples) to the next multiple of eval_batch_size, so that # we don't have to extract the correctly sized subset of the training data. - rounded_up_multiple = math.ceil(self.num_validation_examples / - self.eval_batch_size) + rounded_up_multiple = math.ceil( + self.num_validation_examples / self.eval_batch_size + ) return rounded_up_multiple * self.eval_batch_size @property @@ -93,37 +94,35 @@ def eval_period_time_sec(self) -> int: return 600 # 10 mins. 
def _build_dataset( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, ) -> Iterator[Dict[str, spec.Tensor]]: raise NotImplementedError def _build_input_queue( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del num_batches if split == 'test': if not cache: raise ValueError('cache must be True for split=test.') if not repeat_final_dataset: raise ValueError('repeat_final_dataset must be True for split=test.') - return self._build_dataset(data_rng, - split, - data_dir, - global_batch_size, - cache, - repeat_final_dataset) + return self._build_dataset( + data_rng, split, data_dir, global_batch_size, cache, repeat_final_dataset + ) @property def step_hint(self) -> int: @@ -133,39 +132,43 @@ def step_hint(self) -> int: return 4883 def _eval_model( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: raise NotImplementedError @abc.abstractmethod def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del global_step data_rng, model_rng = prng.split(rng, 2) if split not in self._eval_iters: self._eval_iters[split] = self._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - cache=True, - repeat_final_dataset=True) + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size, + cache=True, + repeat_final_dataset=True, + ) num_batches = int(math.ceil(num_examples / global_batch_size)) num_devices = max(torch.cuda.device_count(), jax.local_device_count()) @@ -174,10 +177,9 @@ def _eval_model_on_split(self, batch = next(self._eval_iters[split]) per_device_model_rngs = prng.split(model_rng, num_devices) # We already average these metrics across devices inside _compute_metrics. 
- synced_metrics = self._eval_model(params, - batch, - model_state, - per_device_model_rngs) + synced_metrics = self._eval_model( + params, batch, model_state, per_device_model_rngs + ) for metric_name, metric_value in synced_metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0 diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py index 6d9a489ff..706c2b51a 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py @@ -3,8 +3,12 @@ from typing import Sequence import flax.linen as nn -from jax import nn as jnn import jax.numpy as jnp +from jax import nn as jnn + +from algoperf.jax_utils import Dropout + +DROPOUT_RATE = 0.0 class DLRMResNet(nn.Module): @@ -23,12 +27,12 @@ class DLRMResNet(nn.Module): mlp_bottom_dims: Sequence[int] = (256, 256, 256) mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1) embed_dim: int = 128 - dropout_rate: float = 0.0 + dropout_rate: float = DROPOUT_RATE use_layer_norm: bool = False # Unused. embedding_init_multiplier: float = None # Unused @nn.compact - def __call__(self, x, train): + def __call__(self, x, train, dropout_rate=DROPOUT_RATE): bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) @@ -36,20 +40,18 @@ def __call__(self, x, train): mlp_bottom_dims = self.mlp_bottom_dims bot_mlp_input = nn.Dense( - mlp_bottom_dims[0], - kernel_init=jnn.initializers.glorot_uniform(), - bias_init=jnn.initializers.normal(stddev=1.0 / mlp_bottom_dims[0]**0.5), - )( - bot_mlp_input) + mlp_bottom_dims[0], + kernel_init=jnn.initializers.glorot_uniform(), + bias_init=jnn.initializers.normal(stddev=1.0 / mlp_bottom_dims[0] ** 0.5), + )(bot_mlp_input) bot_mlp_input = nn.relu(bot_mlp_input) for dense_dim in mlp_bottom_dims[1:]: x = nn.Dense( - dense_dim, - kernel_init=jnn.initializers.glorot_uniform(), - bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5), - )( - bot_mlp_input) + dense_dim, + kernel_init=jnn.initializers.glorot_uniform(), + bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5), + )(bot_mlp_input) bot_mlp_input += nn.relu(x) base_init_fn = jnn.initializers.uniform(scale=1.0) @@ -59,46 +61,51 @@ def __call__(self, x, train): def scaled_init(key, shape, dtype=jnp.float_): return base_init_fn(key, shape, dtype) / jnp.sqrt(self.vocab_size) - embedding_table = self.param('embedding_table', - scaled_init, [self.vocab_size, self.embed_dim]) + embedding_table = self.param( + 'embedding_table', scaled_init, [self.vocab_size, self.embed_dim] + ) embed_features = embedding_table[idx_lookup] batch_size = bot_mlp_input.shape[0] - embed_features = jnp.reshape(embed_features, - (batch_size, 26 * self.embed_dim)) + embed_features = jnp.reshape( + embed_features, (batch_size, 26 * self.embed_dim) + ) top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1) mlp_input_dim = top_mlp_input.shape[1] mlp_top_dims = self.mlp_top_dims num_layers_top = len(mlp_top_dims) top_mlp_input = nn.Dense( - mlp_top_dims[0], - kernel_init=jnn.initializers.normal( - stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0]))), - bias_init=jnn.initializers.normal( - stddev=jnp.sqrt(1.0 / mlp_top_dims[0])))( - top_mlp_input) + mlp_top_dims[0], + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0])) + ), + bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / mlp_top_dims[0])), + 
)(top_mlp_input) top_mlp_input = nn.relu(top_mlp_input) for layer_idx, fan_out in list(enumerate(mlp_top_dims))[1:-1]: fan_in = mlp_top_dims[layer_idx - 1] x = nn.Dense( - fan_out, - kernel_init=jnn.initializers.normal( - stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), - bias_init=jnn.initializers.normal( - stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))( - top_mlp_input) + fan_out, + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (fan_in + fan_out)) + ), + bias_init=jnn.initializers.normal( + stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx]) + ), + )(top_mlp_input) x = nn.relu(x) - if self.dropout_rate and layer_idx == num_layers_top - 2: - x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) + if dropout_rate and layer_idx == num_layers_top - 2: + x = Dropout(dropout_rate, deterministic=not train)(x, rate=dropout_rate) top_mlp_input += x # In the DLRM model the last layer width is always 1. We can hardcode that # below. logits = nn.Dense( - 1, - kernel_init=jnn.initializers.normal( - stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))), - bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0)))( - top_mlp_input) + 1, + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1)) + ), + bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0)), + )(top_mlp_input) return logits @@ -114,16 +121,18 @@ def dot_interact(concat_features): batch_size = concat_features.shape[0] # Interact features, select upper or lower-triangular portion, and reshape. - xactions = jnp.matmul(concat_features, - jnp.transpose(concat_features, [0, 2, 1])) + xactions = jnp.matmul( + concat_features, jnp.transpose(concat_features, [0, 2, 1]) + ) feature_dim = xactions.shape[-1] indices = jnp.array(jnp.triu_indices(feature_dim)) num_elems = indices.shape[1] indices = jnp.tile(indices, [1, batch_size]) indices0 = jnp.reshape( - jnp.tile(jnp.reshape(jnp.arange(batch_size), [-1, 1]), [1, num_elems]), - [1, -1]) + jnp.tile(jnp.reshape(jnp.arange(batch_size), [-1, 1]), [1, num_elems]), + [1, -1], + ) indices = tuple(jnp.concatenate((indices0, indices), 0)) activations = xactions[indices] activations = jnp.reshape(activations, [batch_size, -1]) @@ -151,25 +160,25 @@ class DlrmSmall(nn.Module): embedding_init_multiplier: float = None @nn.compact - def __call__(self, x, train): + def __call__(self, x, train, dropout_rate=DROPOUT_RATE): bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) # Bottom MLP. for dense_dim in self.mlp_bottom_dims: bot_mlp_input = nn.Dense( - dense_dim, - kernel_init=jnn.initializers.glorot_uniform(), - bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), - )( - bot_mlp_input) + dense_dim, + kernel_init=jnn.initializers.glorot_uniform(), + bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), + )(bot_mlp_input) bot_mlp_input = nn.relu(bot_mlp_input) if self.use_layer_norm: bot_mlp_input = nn.LayerNorm()(bot_mlp_input) bot_mlp_output = bot_mlp_input batch_size = bot_mlp_output.shape[0] - feature_stack = jnp.reshape(bot_mlp_output, - [batch_size, -1, self.embed_dim]) + feature_stack = jnp.reshape( + bot_mlp_output, [batch_size, -1, self.embed_dim] + ) # Embedding table look-up. 
idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size @@ -182,38 +191,45 @@ def __call__(self, x, train): def scaled_init(key, shape, dtype=jnp.float_): return jnn.initializers.uniform(scale=1.0)(key, shape, dtype) * scale - embedding_table = self.param('embedding_table', - scaled_init, [self.vocab_size, self.embed_dim]) + embedding_table = self.param( + 'embedding_table', scaled_init, [self.vocab_size, self.embed_dim] + ) idx_lookup = jnp.reshape(idx_lookup, [-1]) embed_features = embedding_table[idx_lookup] - embed_features = jnp.reshape(embed_features, - [batch_size, -1, self.embed_dim]) + embed_features = jnp.reshape( + embed_features, [batch_size, -1, self.embed_dim] + ) if self.use_layer_norm: embed_features = nn.LayerNorm()(embed_features) feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1) dot_interact_output = dot_interact(concat_features=feature_stack) - top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output], - axis=-1) + top_mlp_input = jnp.concatenate( + [bot_mlp_output, dot_interact_output], axis=-1 + ) mlp_input_dim = top_mlp_input.shape[1] mlp_top_dims = self.mlp_top_dims num_layers_top = len(mlp_top_dims) for layer_idx, fan_out in enumerate(mlp_top_dims): fan_in = mlp_input_dim if layer_idx == 0 else mlp_top_dims[layer_idx - 1] top_mlp_input = nn.Dense( - fan_out, - kernel_init=jnn.initializers.normal( - stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), - bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))( - top_mlp_input) + fan_out, + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (fan_in + fan_out)) + ), + bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)), + )(top_mlp_input) if layer_idx < (num_layers_top - 1): top_mlp_input = nn.relu(top_mlp_input) if self.use_layer_norm: top_mlp_input = nn.LayerNorm()(top_mlp_input) - if (self.dropout_rate is not None and self.dropout_rate > 0.0 and - layer_idx == num_layers_top - 2): - top_mlp_input = nn.Dropout( - rate=self.dropout_rate, deterministic=not train)( - top_mlp_input) + if ( + dropout_rate is not None + and dropout_rate > 0.0 + and layer_idx == num_layers_top - 2 + ): + top_mlp_input = Dropout(dropout_rate, deterministic=not train)( + top_mlp_input, rate=dropout_rate + ) logits = top_mlp_input return logits diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index 91761e458..283b3be8e 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -3,26 +3,24 @@ import functools from typing import Dict, Optional, Tuple -from flax import jax_utils import jax import jax.numpy as jnp import numpy as np +from flax import jax_utils -from algoperf import param_utils -from algoperf import spec +from algoperf import param_utils, spec from algoperf.workloads.criteo1tb.criteo1tb_jax import models -from algoperf.workloads.criteo1tb.workload import \ - BaseCriteo1TbDlrmSmallWorkload +from algoperf.workloads.criteo1tb.workload import BaseCriteo1TbDlrmSmallWorkload class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload): - @property def eval_batch_size(self) -> int: return 131_072 def _per_example_sigmoid_binary_cross_entropy( - self, logits: spec.Tensor, targets: spec.Tensor) -> spec.Tensor: + self, logits: spec.Tensor, targets: spec.Tensor + ) -> spec.Tensor: """Computes the sigmoid binary cross entropy per example. 
Args: @@ -39,11 +37,12 @@ def _per_example_sigmoid_binary_cross_entropy( # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense (not one-hot) labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense (not one-hot) labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -55,7 +54,8 @@ def loss_fn( label_batch = jnp.reshape(label_batch, (batch_size,)) logits_batch = jnp.reshape(logits_batch, (batch_size,)) per_example_losses = self._per_example_sigmoid_binary_cross_entropy( - logits=logits_batch, targets=label_batch) + logits=logits_batch, targets=label_batch + ) if mask_batch is not None: mask_batch = jnp.reshape(mask_batch, (batch_size,)) per_example_losses *= mask_batch @@ -64,35 +64,33 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None, - tabulate: Optional[bool] = False, + self, + rng: spec.RandomState, + tabulate: Optional[bool] = False, ) -> spec.ModelInitState: """Only dropout is used.""" - del aux_dropout_rate if self.use_resnet: model_class = models.DLRMResNet else: model_class = models.DlrmSmall + self._model = model_class( - vocab_size=self.vocab_size, - num_dense_features=self.num_dense_features, - mlp_bottom_dims=self.mlp_bottom_dims, - mlp_top_dims=self.mlp_top_dims, - embed_dim=self.embed_dim, - dropout_rate=dropout_rate, - use_layer_norm=self.use_layer_norm, - embedding_init_multiplier=self.embedding_init_multiplier) - - params_rng, dropout_rng = jax.random.split(rng) + vocab_size=self.vocab_size, + num_dense_features=self.num_dense_features, + mlp_bottom_dims=self.mlp_bottom_dims, + mlp_top_dims=self.mlp_top_dims, + embed_dim=self.embed_dim, + use_layer_norm=self.use_layer_norm, + embedding_init_multiplier=self.embedding_init_multiplier, + ) + + params_rng, _ = jax.random.split(rng) init_fake_batch_size = 2 num_categorical_features = 26 num_dense_features = 13 @@ -100,8 +98,11 @@ def init_model_fn( input_shape = (init_fake_batch_size, input_size) init_fn = functools.partial(self._model.init, train=False) initial_variables = jax.jit(init_fn)( - {'params': params_rng, 'dropout': dropout_rng}, - jnp.ones(input_shape, jnp.float32)) + { + 'params': params_rng, + }, + jnp.ones(input_shape, jnp.float32), + ) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -111,13 +112,15 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_7' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: 
spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm inputs = augmented_and_preprocessed_input_batch['inputs'] @@ -125,39 +128,43 @@ def model_fn( apply_kwargs = {'train': train} if train: apply_kwargs['rngs'] = {'dropout': rng} + apply_kwargs['dropout_rate'] = dropout_rate logits_batch = self._model.apply({'params': params}, inputs, **apply_kwargs) return logits_batch, None @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0), - static_broadcasted_argnums=(0,)) - def _eval_batch_pmapped(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> spec.Tensor: + jax.pmap, + axis_name='batch', + in_axes=(None, 0, 0), + static_broadcasted_argnums=(0,), + ) + def _eval_batch_pmapped( + self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor] + ) -> spec.Tensor: logits, _ = self.model_fn( - params, - batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + params, + batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) weights = batch.get('weights') if weights is None: weights = jnp.ones(len(logits)) summed_loss = self.loss_fn( - label_batch=batch['targets'], logits_batch=logits, - mask_batch=weights)['summed'] + label_batch=batch['targets'], logits_batch=logits, mask_batch=weights + )['summed'] return summed_loss - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> spec.Tensor: + def _eval_batch( + self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor] + ) -> spec.Tensor: # We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of # shape (local_device_count,) will all be different values. 
return np.array( - self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64) + self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64 + ) class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): @@ -165,7 +172,6 @@ class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): class Criteo1TbDlrmSmallLayerNormWorkload(Criteo1TbDlrmSmallWorkload): - @property def use_layer_norm(self) -> bool: """Whether or not to use LayerNorm in the model.""" @@ -199,7 +205,6 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallEmbedInitWorkload(Criteo1TbDlrmSmallWorkload): - @property def validation_target_value(self) -> float: return 0.129657 diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py index 7a40f0e81..1906bf7ae 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -5,9 +5,13 @@ import torch from torch import nn +from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout + +DROPOUT_RATE = 0.0 + class DenseBlock(nn.Module): - """Dense block with optional residual connection.""" "" + """Dense block with optional residual connection.""" '' def __init__(self, module, resnet=False): super().__init__() @@ -15,10 +19,20 @@ def __init__(self, module, resnet=False): self.resnet = resnet def forward(self, x): - if self.resnet: - return self.module(x) + x - else: - return self.module(x) + return self.module(x) + x if self.resnet else self.module(x) + + +class DenseBlockWithDropout(nn.Module): + """Dense block with optional residual connection and support for dropout.""" + + def __init__(self, module, resnet=False): + super().__init__() + self.module = module + self.resnet = resnet + self._supports_custom_dropout = True + + def forward(self, x, p): + return self.module(x, p) + x if self.resnet else self.module(x, p) class DotInteract(nn.Module): @@ -26,17 +40,20 @@ class DotInteract(nn.Module): def __init__(self, num_sparse_features): super().__init__() - self.triu_indices = torch.triu_indices(num_sparse_features + 1, - num_sparse_features + 1) + self.triu_indices = torch.triu_indices( + num_sparse_features + 1, num_sparse_features + 1 + ) def forward(self, dense_features, sparse_features): - combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features), - dim=1) - interactions = torch.bmm(combined_values, - torch.transpose(combined_values, 1, 2)) - interactions_flat = interactions[:, - self.triu_indices[0], - self.triu_indices[1]] + combined_values = torch.cat( + (dense_features.unsqueeze(1), sparse_features), dim=1 + ) + interactions = torch.bmm( + combined_values, torch.transpose(combined_values, 1, 2) + ) + interactions_flat = interactions[ + :, self.triu_indices[0], self.triu_indices[1] + ] return torch.cat((dense_features, interactions_flat), dim=1) @@ -51,16 +68,17 @@ class DLRMResNet(nn.Module): embed_dim: embedding dimension. 
""" - def __init__(self, - vocab_size, - num_dense_features=13, - num_sparse_features=26, - mlp_bottom_dims=(256, 256, 256), - mlp_top_dims=(256, 256, 256, 256, 1), - embed_dim=128, - dropout_rate=0.0, - use_layer_norm=False, - embedding_init_multiplier=None): + def __init__( + self, + vocab_size, + num_dense_features=13, + num_sparse_features=26, + mlp_bottom_dims=(256, 256, 256), + mlp_top_dims=(256, 256, 256, 256, 1), + embed_dim=128, + use_layer_norm=False, + embedding_init_multiplier=None, + ): super().__init__() self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) self.num_dense_features = num_dense_features @@ -78,7 +96,8 @@ def __init__(self, scale = 1.0 / torch.sqrt(self.vocab_size) for i in range(num_chunks): chunk = nn.Parameter( - torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) + torch.Tensor(self.vocab_size // num_chunks, self.embed_dim) + ) chunk.data.uniform_(0, 1) chunk.data = scale * chunk.data self.register_parameter(f'embedding_chunk_{i}', chunk) @@ -101,11 +120,11 @@ def __init__(self, for module in self.bot_mlp.modules(): if isinstance(module, nn.Linear): - limit = math.sqrt(6. / (module.in_features + module.out_features)) + limit = math.sqrt(6.0 / (module.in_features + module.out_features)) nn.init.uniform_(module.weight.data, -limit, limit) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) + nn.init.normal_( + module.bias.data, 0.0, math.sqrt(1.0 / module.out_features) + ) # Number of sparse features = 26 fan_in = (26 * self.embed_dim) + self.mlp_bottom_dims[-1] @@ -116,33 +135,34 @@ def __init__(self, block.append(nn.Linear(fan_in, fan_out)) if layer_idx < (num_layers_top - 1): block.append(nn.ReLU(inplace=True)) - if (dropout_rate is not None and dropout_rate > 0.0 and - layer_idx == num_layers_top - 2): - block.append(nn.Dropout(p=dropout_rate)) - block = nn.Sequential(*block) + if layer_idx == num_layers_top - 2: + block.append(CustomDropout()) + block = SequentialWithDropout(*block) if (layer_idx != 0) and (layer_idx != num_layers_top - 1): - block = DenseBlock(block, resnet=True) + block = DenseBlockWithDropout(block, resnet=True) else: - block = DenseBlock(block) + block = DenseBlockWithDropout(block) mlp_top_blocks.append(block) fan_in = fan_out - self.top_mlp = nn.Sequential(*mlp_top_blocks) + self.top_mlp = SequentialWithDropout(*mlp_top_blocks) for module in self.top_mlp.modules(): if isinstance(module, nn.Linear): nn.init.normal_( - module.weight.data, - 0., - math.sqrt(2. / (module.in_features + module.out_features))) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) + module.weight.data, + 0.0, + math.sqrt(2.0 / (module.in_features + module.out_features)), + ) + nn.init.normal_( + module.bias.data, 0.0, math.sqrt(1.0 / module.out_features) + ) - def forward(self, x): + def forward(self, x, dropout_rate=DROPOUT_RATE): batch_size = x.shape[0] dense_features, sparse_features = torch.split( - x, [self.num_dense_features, self.num_sparse_features], 1) + x, [self.num_dense_features, self.num_sparse_features], 1 + ) # Bottom MLP. 
embedded_dense = self.bot_mlp(dense_features) @@ -152,12 +172,13 @@ def forward(self, x): idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size embedding_table = torch.cat(self.embedding_table_chucks, dim=0) embedded_sparse = embedding_table[idx_lookup] - embedded_sparse = torch.reshape(embedded_sparse, - [batch_size, 26 * self.embed_dim]) + embedded_sparse = torch.reshape( + embedded_sparse, [batch_size, 26 * self.embed_dim] + ) top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1) # Final MLP. - logits = self.top_mlp(top_mlp_input) + logits = self.top_mlp(top_mlp_input, dropout_rate) return logits @@ -172,16 +193,17 @@ class DlrmSmall(nn.Module): embed_dim: embedding dimension. """ - def __init__(self, - vocab_size, - num_dense_features=13, - num_sparse_features=26, - mlp_bottom_dims=(512, 256, 128), - mlp_top_dims=(1024, 1024, 512, 256, 1), - embed_dim=128, - dropout_rate=0.0, - use_layer_norm=False, - embedding_init_multiplier=None): + def __init__( + self, + vocab_size, + num_dense_features=13, + num_sparse_features=26, + mlp_bottom_dims=(512, 256, 128), + mlp_top_dims=(1024, 1024, 512, 256, 1), + embed_dim=128, + use_layer_norm=False, + embedding_init_multiplier=None, + ): super().__init__() self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) self.num_dense_features = num_dense_features @@ -205,7 +227,8 @@ def __init__(self, for i in range(num_chunks): chunk = nn.Parameter( - torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)) + torch.Tensor(self.vocab_size // num_chunks, self.embed_dim) + ) chunk.data.uniform_(0, 1) chunk.data = scale * chunk.data self.register_parameter(f'embedding_chunk_{i}', chunk) @@ -222,30 +245,32 @@ def __init__(self, self.bot_mlp = nn.Sequential(*bottom_mlp_layers) for module in self.bot_mlp.modules(): if isinstance(module, nn.Linear): - limit = math.sqrt(6. / (module.in_features + module.out_features)) + limit = math.sqrt(6.0 / (module.in_features + module.out_features)) nn.init.uniform_(module.weight.data, -limit, limit) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) + nn.init.normal_( + module.bias.data, 0.0, math.sqrt(1.0 / module.out_features) + ) - self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) + self.dot_interact = DotInteract( + num_sparse_features=num_sparse_features, + ) # TODO: Write down the formula here instead of the constant. input_dims = 506 num_layers_top = len(self.mlp_top_dims) top_mlp_layers = [] for layer_idx, fan_out in enumerate(self.mlp_top_dims): - fan_in = input_dims if layer_idx == 0 \ - else self.mlp_top_dims[layer_idx - 1] + fan_in = ( + input_dims if layer_idx == 0 else self.mlp_top_dims[layer_idx - 1] + ) top_mlp_layers.append(nn.Linear(fan_in, fan_out)) if layer_idx < (num_layers_top - 1): top_mlp_layers.append(nn.ReLU(inplace=True)) if use_layer_norm: top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) - if (dropout_rate is not None and dropout_rate > 0.0 and - layer_idx == num_layers_top - 2): - top_mlp_layers.append(nn.Dropout(p=dropout_rate)) - self.top_mlp = nn.Sequential(*top_mlp_layers) + if layer_idx == num_layers_top - 2: + top_mlp_layers.append(CustomDropout()) + self.top_mlp = SequentialWithDropout(*top_mlp_layers) if use_layer_norm: self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) else: @@ -253,18 +278,20 @@ def __init__(self, for module in self.top_mlp.modules(): if isinstance(module, nn.Linear): nn.init.normal_( - module.weight.data, - 0., - math.sqrt(2. 
/ (module.in_features + module.out_features))) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) + module.weight.data, + 0.0, + math.sqrt(2.0 / (module.in_features + module.out_features)), + ) + nn.init.normal_( + module.bias.data, 0.0, math.sqrt(1.0 / module.out_features) + ) - def forward(self, x): + def forward(self, x, dropout_rate=DROPOUT_RATE): batch_size = x.shape[0] dense_features, sparse_features = torch.split( - x, [self.num_dense_features, self.num_sparse_features], 1) + x, [self.num_dense_features, self.num_sparse_features], 1 + ) # Bottom MLP. embedded_dense = self.bot_mlp(dense_features) @@ -274,14 +301,16 @@ def forward(self, x): idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size embedding_table = torch.cat(self.embedding_table_chucks, dim=0) embedded_sparse = embedding_table[idx_lookup] - embedded_sparse = torch.reshape(embedded_sparse, - [batch_size, -1, self.embed_dim]) + embedded_sparse = torch.reshape( + embedded_sparse, [batch_size, -1, self.embed_dim] + ) if self.embed_ln: embedded_sparse = self.embed_ln(embedded_sparse) # Dot product interactions. concatenated_dense = self.dot_interact( - dense_features=embedded_dense, sparse_features=embedded_sparse) + dense_features=embedded_dense, sparse_features=embedded_sparse + ) # Final MLP. - logits = self.top_mlp(concatenated_dense) + logits = self.top_mlp(concatenated_dense, dropout_rate) return logits diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 726aa8705..74f91de43 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -7,24 +7,22 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils -from algoperf import spec +from algoperf import param_utils, spec from algoperf.pytorch_utils import pytorch_setup from algoperf.workloads.criteo1tb.criteo1tb_pytorch import models -from algoperf.workloads.criteo1tb.workload import \ - BaseCriteo1TbDlrmSmallWorkload +from algoperf.workloads.criteo1tb.workload import BaseCriteo1TbDlrmSmallWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload): - @property def eval_batch_size(self) -> int: return 8_192 def _per_example_sigmoid_binary_cross_entropy( - self, logits: spec.Tensor, targets: spec.Tensor) -> spec.Tensor: + self, logits: spec.Tensor, targets: spec.Tensor + ) -> spec.Tensor: ls = torch.nn.LogSigmoid() log_p = ls(logits) log_not_p = ls(-logits) @@ -35,11 +33,12 @@ def _per_example_sigmoid_binary_cross_entropy( # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense (not one-hot) labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense (not one-hot) labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). 
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -51,7 +50,8 @@ def loss_fn( label_batch = torch.reshape(label_batch, (batch_size,)) logits_batch = torch.reshape(logits_batch, (batch_size,)) per_example_losses = self._per_example_sigmoid_binary_cross_entropy( - logits=logits_batch, targets=label_batch) + logits=logits_batch, targets=label_batch + ) if mask_batch is not None: mask_batch = torch.reshape(mask_batch, (batch_size,)) per_example_losses *= mask_batch @@ -60,18 +60,12 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), + 'per_example': per_example_losses, } - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Only dropout is used.""" - del aux_dropout_rate + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) # Disable cudnn benchmark to avoid OOM errors. torch.backends.cudnn.benchmark = False @@ -80,14 +74,14 @@ def init_model_fn( else: model_class = models.DlrmSmall model = model_class( - vocab_size=self.vocab_size, - num_dense_features=self.num_dense_features, - mlp_bottom_dims=self.mlp_bottom_dims, - mlp_top_dims=self.mlp_top_dims, - embed_dim=self.embed_dim, - dropout_rate=dropout_rate, - use_layer_norm=self.use_layer_norm, - embedding_init_multiplier=self.embedding_init_multiplier) + vocab_size=self.vocab_size, + num_dense_features=self.num_dense_features, + mlp_bottom_dims=self.mlp_bottom_dims, + mlp_top_dims=self.mlp_top_dims, + embed_dim=self.embed_dim, + use_layer_norm=self.use_layer_norm, + embedding_init_multiplier=self.embedding_init_multiplier, + ) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -102,13 +96,15 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['top_mlp.4.weight', 'top_mlp.4.bias'] def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -123,24 +119,25 @@ def model_fn( model.train() contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): - logits_batch = model(inputs) + logits_batch = model(inputs, dropout_rate=dropout_rate) return logits_batch, None def _build_input_queue( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - 
num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: not_train = split != 'train' per_device_batch_size = int(global_batch_size / N_GPUS) @@ -148,35 +145,42 @@ def _build_input_queue( # avoid creating too many threads. if RANK == 0: np_iter = super()._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=repeat_final_dataset, + ) weights = None while True: if RANK == 0: batch = next(np_iter) # pylint: disable=stop-iteration-return inputs = torch.as_tensor( - batch['inputs'], dtype=torch.float32, device=DEVICE) + batch['inputs'], dtype=torch.float32, device=DEVICE + ) targets = torch.as_tensor( - batch['targets'], dtype=torch.float32, device=DEVICE) + batch['targets'], dtype=torch.float32, device=DEVICE + ) if not_train: weights = batch.get('weights') if weights is None: - weights = torch.ones((N_GPUS, per_device_batch_size, 1), - dtype=torch.float32, - device=DEVICE) + weights = torch.ones( + (N_GPUS, per_device_batch_size, 1), + dtype=torch.float32, + device=DEVICE, + ) else: weights = torch.as_tensor( - weights, dtype=torch.float32, device=DEVICE) + weights, dtype=torch.float32, device=DEVICE + ) # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: if not_train: # During eval, the batch size of the remainder might be different. per_device_batch_size = torch.tensor( - len(targets[0]), dtype=torch.int32, device=DEVICE) + len(targets[0]), dtype=torch.int32, device=DEVICE + ) dist.broadcast(per_device_batch_size, src=0) dist.broadcast(weights, src=0) weights = weights[0] @@ -192,52 +196,57 @@ def _build_input_queue( else: if not_train: # During eval, the batch size of the remainder might be different. 
- per_device_batch_size = torch.empty((1,), - dtype=torch.int32, - device=DEVICE) + per_device_batch_size = torch.empty( + (1,), dtype=torch.int32, device=DEVICE + ) dist.broadcast(per_device_batch_size, src=0) - weights = torch.empty((N_GPUS, per_device_batch_size, 1), - dtype=torch.float32, - device=DEVICE) + weights = torch.empty( + (N_GPUS, per_device_batch_size, 1), + dtype=torch.float32, + device=DEVICE, + ) dist.broadcast(weights, src=0) weights = weights[RANK] - inputs = torch.empty((N_GPUS, per_device_batch_size, 39), - dtype=torch.float32, - device=DEVICE) + inputs = torch.empty( + (N_GPUS, per_device_batch_size, 39), + dtype=torch.float32, + device=DEVICE, + ) dist.broadcast(inputs, src=0) inputs = inputs[RANK] - targets = torch.empty((N_GPUS, per_device_batch_size, 1), - dtype=torch.float32, - device=DEVICE) + targets = torch.empty( + (N_GPUS, per_device_batch_size, 1), dtype=torch.float32, device=DEVICE + ) dist.broadcast(targets, src=0) targets = targets[RANK] if weights is None: weights = torch.ones(per_device_batch_size, device=DEVICE) batch = { - 'inputs': inputs, - 'targets': targets, - 'weights': weights, + 'inputs': inputs, + 'targets': targets, + 'weights': weights, } yield batch - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> spec.Tensor: + def _eval_batch( + self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor] + ) -> spec.Tensor: logits, _ = self.model_fn( - params, - batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + params, + batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) weights = batch.get('weights') if weights is None: weights = torch.ones(len(logits), device=DEVICE) summed_loss = self.loss_fn( - label_batch=batch['targets'], logits_batch=logits, - mask_batch=weights)['summed'] + label_batch=batch['targets'], logits_batch=logits, mask_batch=weights + )['summed'] return summed_loss.to(dtype=torch.float64) @@ -246,7 +255,6 @@ class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): class Criteo1TbDlrmSmallLayerNormWorkload(Criteo1TbDlrmSmallWorkload): - @property def use_layer_norm(self) -> bool: """Whether or not to use LayerNorm in the model.""" @@ -280,7 +288,6 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallEmbedInitWorkload(Criteo1TbDlrmSmallWorkload): - @property def validation_target_value(self) -> float: return 0.129657 diff --git a/algoperf/workloads/criteo1tb/input_pipeline.py b/algoperf/workloads/criteo1tb/input_pipeline.py index 7e254336a..bce8b11c4 100644 --- a/algoperf/workloads/criteo1tb/input_pipeline.py +++ b/algoperf/workloads/criteo1tb/input_pipeline.py @@ -19,32 +19,32 @@ # Raw vocab sizes from # https://cloud.google.com/tpu/docs/tutorials/dlrm-dcn-2.x#run-model. 
_VOCAB_SIZES = [ - 39884406, - 39043, - 17289, - 7420, - 20263, - 3, - 7120, - 1543, - 63, - 38532951, - 2953546, - 403346, - 10, - 2208, - 11938, - 155, - 4, - 976, - 14, - 39979771, - 25641295, - 39664984, - 585935, - 12972, - 108, - 36, + 39884406, + 39043, + 17289, + 7420, + 20263, + 3, + 7120, + 1543, + 63, + 38532951, + 2953546, + 403346, + 10, + 2208, + 11938, + 155, + 4, + 976, + 14, + 39979771, + 25641295, + 39664984, + 585935, + 12972, + 108, + 36, ] @@ -60,7 +60,8 @@ def _parse_example_fn(num_dense_features, example): categorical_defaults = [['00000000'] for _ in range(len(_VOCAB_SIZES))] record_defaults = label_defaults + int_defaults + categorical_defaults fields = tf.io.decode_csv( - example, record_defaults, field_delim='\t', na_value='-1') + example, record_defaults, field_delim='\t', na_value='-1' + ) num_labels = 1 features = {} @@ -78,20 +79,24 @@ def _parse_example_fn(num_dense_features, example): # We append the column index to the string to make the same id in different # columns unique. cat_features.append( - tf.strings.to_hash_bucket_fast(field + str(idx), _VOCAB_SIZES[idx])) + tf.strings.to_hash_bucket_fast(field + str(idx), _VOCAB_SIZES[idx]) + ) cat_features = tf.cast( - tf.stack(cat_features, axis=1), dtype=int_features.dtype) + tf.stack(cat_features, axis=1), dtype=int_features.dtype + ) features['inputs'] = tf.concat([int_features, cat_features], axis=1) return features -def get_criteo1tb_dataset(split: str, - shuffle_rng, - data_dir: str, - num_dense_features: int, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): +def get_criteo1tb_dataset( + split: str, + shuffle_rng, + data_dir: str, + num_dense_features: int, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, +): """Get the Criteo 1TB dataset for a given split.""" num_test_files = _NUM_DAY_23_FILES // 2 + 1 if split in ['train', 'eval_train']: @@ -99,19 +104,20 @@ def get_criteo1tb_dataset(split: str, elif split == 'validation': # Assumes files are of the format day_23_04. 
file_paths = [ - os.path.join(data_dir, f'day_23_{str(s).zfill(2)}') - for s in range(num_test_files, _NUM_DAY_23_FILES) + os.path.join(data_dir, f'day_23_{str(s).zfill(2)}') + for s in range(num_test_files, _NUM_DAY_23_FILES) ] else: file_paths = [ - os.path.join(data_dir, f'day_23_{str(s).zfill(2)}') - for s in range(0, num_test_files) + os.path.join(data_dir, f'day_23_{str(s).zfill(2)}') + for s in range(0, num_test_files) ] is_training = split == 'train' shuffle = is_training or split == 'eval_train' ds = tf.data.Dataset.list_files( - file_paths, shuffle=shuffle, seed=shuffle_rng[0]) + file_paths, shuffle=shuffle, seed=shuffle_rng[0] + ) if shuffle: ds = ds.shuffle(buffer_size=1024) @@ -132,9 +138,10 @@ def get_criteo1tb_dataset(split: str, ds = ds.repeat() ds = map( - functools.partial( - data_utils.shard_and_maybe_pad_np, - global_batch_size=global_batch_size), - ds) + functools.partial( + data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size + ), + ds, + ) return ds diff --git a/algoperf/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py index 617b2e987..2cb7e5450 100644 --- a/algoperf/workloads/criteo1tb/workload.py +++ b/algoperf/workloads/criteo1tb/workload.py @@ -4,8 +4,8 @@ import os from typing import Dict, Iterator, Optional, Tuple -from absl import flags import torch.distributed as dist +from absl import flags from algoperf import spec from algoperf.workloads.criteo1tb import input_pipeline @@ -29,8 +29,9 @@ def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" return 'loss' - def has_reached_validation_target(self, eval_result: Dict[str, - float]) -> bool: + def has_reached_validation_target( + self, eval_result: Dict[str, float] + ) -> bool: return eval_result['validation/loss'] < self.validation_target_value @property @@ -71,8 +72,9 @@ def num_eval_train_examples(self) -> int: # Round up from num_validation_examples (which is the default for # num_eval_train_examples) to the next multiple of eval_batch_size, so that # we don't have to extract the correctly sized subset of the training data. - rounded_up_multiple = math.ceil(self.num_validation_examples / - self.eval_batch_size) + rounded_up_multiple = math.ceil( + self.num_validation_examples / self.eval_batch_size + ) return rounded_up_multiple * self.eval_batch_size @property @@ -100,23 +102,25 @@ def eval_period_time_sec(self) -> int: return 2 * 60 # 2 mins. 
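Since `init_model_fn` no longer takes a `dropout_rate` and `model_fn` now accepts one per call (see the new signatures in the workload hunks above), a training algorithm can pick the rate at step time. A rough sketch of such a forward pass, assuming a `workload` and `batch` shaped as in these workload classes; the `hyperparameters.dropout_rate` field is a hypothetical name, everything else follows the signatures in this diff.

```python
# Sketch only: threading a tuned dropout rate through the new model_fn
# signature. `hyperparameters.dropout_rate` is a hypothetical field.
from algoperf import spec


def forward_pass(workload, params, batch, rng, hyperparameters):
  logits, _ = workload.model_fn(
    params,
    batch,
    model_state=None,
    mode=spec.ForwardPassMode.TRAIN,
    rng=rng,
    update_batch_norm=False,
    dropout_rate=getattr(hyperparameters, 'dropout_rate', 0.0),
  )
  return logits
```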
def _build_input_queue( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del cache ds = input_pipeline.get_criteo1tb_dataset( - split=split, - shuffle_rng=data_rng, - data_dir=data_dir, - num_dense_features=self.num_dense_features, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) + split=split, + shuffle_rng=data_rng, + data_dir=data_dir, + num_dense_features=self.num_dense_features, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=repeat_final_dataset, + ) for batch in iter(ds): yield batch @@ -126,15 +130,17 @@ def step_hint(self) -> int: """Approx. steps the baseline can do in the allowed runtime budget.""" return 10_666 - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del model_state del global_step @@ -142,12 +148,13 @@ def _eval_model_on_split(self, if split not in self._eval_iters: # These iterators will repeat indefinitely. self._eval_iters[split] = self._build_input_queue( - data_rng=rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=True) + data_rng=rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=True, + ) loss = 0.0 for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py index 44bff0e21..b80c370ea 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/models.py +++ b/algoperf/workloads/fastmri/fastmri_jax/models.py @@ -12,13 +12,17 @@ Data: github.com/facebookresearch/fastMRI/tree/main/fastmri/data """ + import functools -from typing import Optional import flax.linen as nn import jax import jax.numpy as jnp +from algoperf.jax_utils import Dropout + +DROPOUT_RATE = 0.0 + def _instance_norm2d(x, axes, epsilon=1e-5): # promote x to at least float32, this avoids half precision computation @@ -28,7 +32,7 @@ def _instance_norm2d(x, axes, epsilon=1e-5): mean2 = jnp.mean(jnp.square(x), axes) # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due # to floating point round-off errors. - var = jnp.maximum(0., mean2 - jnp.square(mean)) + var = jnp.maximum(0.0, mean2 - jnp.square(mean)) stats_shape = list(x.shape) for axis in axes: stats_shape[axis] = 1 @@ -43,39 +47,38 @@ def _instance_norm2d(x, axes, epsilon=1e-5): class UNet(nn.Module): """Jax / Flax implementation of a U-Net model. - O. Ronneberger, P. Fischer, and Thomas Brox. 
U-net: Convolutional networks - for biomedical image segmentation. In International Conference on Medical - image computing and computer-assisted intervention, pages 234–241. - Springer, 2015. + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. - out_channels: Number of channels in the output to the U-Net model. - channels: Number of output channels of the first convolution layer. - num_pool_layers: Number of down-sampling and up-sampling layers. - dropout_rate: Dropout probability. + out_channels: Number of channels in the output to the U-Net model. + channels: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + dropout_rate: Dropout probability. """ + num_channels: int = 32 num_pool_layers: int = 4 out_channels = 1 - dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0. + dropout_rate: float = DROPOUT_RATE use_tanh: bool = False use_layer_norm: bool = False @nn.compact - def __call__(self, x, train=True): - dropout_rate = self.dropout_rate - if dropout_rate is None: - dropout_rate = 0.0 - + def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE): # pylint: disable=invalid-name _ConvBlock = functools.partial( - ConvBlock, - dropout_rate=dropout_rate, - use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm) + ConvBlock, + dropout_rate=dropout_rate, + use_tanh=self.use_tanh, + use_layer_norm=self.use_layer_norm, + ) _TransposeConvBlock = functools.partial( - TransposeConvBlock, - use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm) + TransposeConvBlock, + use_tanh=self.use_tanh, + use_layer_norm=self.use_layer_norm, + ) down_sample_layers = [_ConvBlock(self.num_channels)] @@ -126,9 +129,9 @@ def __call__(self, x, train=True): output = jnp.concatenate((output, downsample_layer), axis=-1) output = conv(output, train) - output = nn.Conv( - self.out_channels, kernel_size=(1, 1), strides=(1, 1))( - output) + output = nn.Conv(self.out_channels, kernel_size=(1, 1), strides=(1, 1))( + output + ) return output.squeeze(-1) @@ -137,13 +140,14 @@ class ConvBlock(nn.Module): out_channels: Number of channels in the output. dropout_rate: Dropout probability. """ + out_channels: int - dropout_rate: float use_tanh: bool use_layer_norm: bool + dropout_rate: float = 0.0 @nn.compact - def __call__(self, x, train=True): + def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE): """Forward function. Note: Pytorch is NCHW and jax/flax is NHWC. Args: @@ -153,11 +157,11 @@ def __call__(self, x, train=True): jnp.array: Output tensor of shape `(N, H, W, out_channels)`. 
""" x = nn.Conv( - features=self.out_channels, - kernel_size=(3, 3), - strides=(1, 1), - use_bias=False)( - x) + features=self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + use_bias=False, + )(x) if self.use_layer_norm: x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x) else: @@ -172,23 +176,23 @@ def __call__(self, x, train=True): x = activation_fn(x) # Ref code uses dropout2d which applies the same mask for the entire channel # Replicated by using broadcast dims to have the same filter on HW - x = nn.Dropout( - self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( - x) + x = Dropout(dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( + x, rate=dropout_rate + ) x = nn.Conv( - features=self.out_channels, - kernel_size=(3, 3), - strides=(1, 1), - use_bias=False)( - x) + features=self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + use_bias=False, + )(x) if self.use_layer_norm: x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x) else: x = _instance_norm2d(x, (1, 2)) x = activation_fn(x) - x = nn.Dropout( - self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( - x) + x = Dropout(dropout_rate, broadcast_dims=(1, 2), deterministic=not train)( + x, rate=dropout_rate + ) return x @@ -196,6 +200,7 @@ class TransposeConvBlock(nn.Module): """A Transpose Convolutional Block. out_channels: Number of channels in the output. """ + out_channels: int use_tanh: bool use_layer_norm: bool @@ -209,8 +214,8 @@ def __call__(self, x): jnp.array: Output tensor of shape `(N, H*2, W*2, out_channels)`. """ x = nn.ConvTranspose( - self.out_channels, kernel_size=(2, 2), strides=(2, 2), use_bias=False)( - x) + self.out_channels, kernel_size=(2, 2), strides=(2, 2), use_bias=False + )(x) x = _instance_norm2d(x, (1, 2)) if self.use_tanh: activation_fn = nn.tanh diff --git a/algoperf/workloads/fastmri/fastmri_jax/ssim.py b/algoperf/workloads/fastmri/fastmri_jax/ssim.py index e15b93616..ca2ee1b60 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/ssim.py +++ b/algoperf/workloads/fastmri/fastmri_jax/ssim.py @@ -49,12 +49,9 @@ def ssim(logits, targets, mean=None, std=None, volume_max=None): return ssims -def structural_similarity(im1, - im2, - data_range=1.0, - win_size=7, - k1=0.01, - k2=0.03): +def structural_similarity( + im1, im2, data_range=1.0, win_size=7, k1=0.01, k2=0.03 +): """Compute the mean structural similarity index between two images. NOTE(dsuo): modified from skimage.metrics.structural_similarity. 
@@ -85,7 +82,7 @@ def structural_similarity(im1, """ filter_func = functools.partial(_uniform_filter, size=win_size) - num_points = win_size**len(im1.shape) + num_points = win_size ** len(im1.shape) # filter has already normalized by num_points cov_norm = num_points / (num_points - 1) # sample covariance @@ -102,8 +99,8 @@ def structural_similarity(im1, vy = cov_norm * (uyy - uy * uy) vxy = cov_norm * (uxy - ux * uy) - c1 = (k1 * data_range)**2 - c2 = (k2 * data_range)**2 + c1 = (k1 * data_range) ** 2 + c2 = (k2 * data_range) ** 2 a1 = 2 * ux * uy + c1 a2 = 2 * vxy + c2 @@ -121,12 +118,15 @@ def structural_similarity(im1, def _uniform_filter(im, size=7): - def conv(im): - return jnp.convolve( + return ( + jnp.convolve( jnp.pad(im, pad_width=size // 2, mode='symmetric'), jnp.ones(size), - mode='valid') / size + mode='valid', + ) + / size + ) im = jax.vmap(conv, (0,))(im) im = jax.vmap(conv, (1,))(im) diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 1156cf30a..08bb25014 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -4,38 +4,34 @@ import math from typing import Dict, Optional, Tuple -from flax import jax_utils import jax import jax.numpy as jnp +from flax import jax_utils -from algoperf import param_utils -from algoperf import spec import algoperf.random_utils as prng -from algoperf.workloads.fastmri.fastmri_jax.models import UNet +from algoperf import param_utils, spec +from algoperf.workloads.fastmri.fastmri_jax.models import DROPOUT_RATE, UNet from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload class FastMRIWorkload(BaseFastMRIWorkload): - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + self, + rng: spec.RandomState, + ) -> spec.ModelInitState: """aux_dropout_rate is unused.""" - del aux_dropout_rate fake_batch = jnp.zeros((13, 320, 320)) self._model = UNet( - num_pool_layers=self.num_pool_layers, - num_channels=self.num_channels, - use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm, - dropout_rate=dropout_rate) - params_rng, dropout_rng = jax.random.split(rng) - variables = jax.jit( - self._model.init)({'params': params_rng, 'dropout': dropout_rng}, - fake_batch) + num_pool_layers=self.num_pool_layers, + num_channels=self.num_channels, + use_tanh=self.use_tanh, + use_layer_norm=self.use_layer_norm, + ) + + params_rng, _ = jax.random.split(rng) + init_fn = functools.partial(self._model.init, train=False) + variables = jax.jit(init_fn)({'params': params_rng}, fake_batch) params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -46,30 +42,37 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Conv_0' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + 
update_batch_norm: bool, + dropout_rate: float = DROPOUT_RATE, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN - logits = self._model.apply({'params': params}, - augmented_and_preprocessed_input_batch['inputs'], - rngs={'dropout': rng}, - train=train) + + logits = self._model.apply( + {'params': params}, + augmented_and_preprocessed_input_batch['inputs'], + rngs={'dropout': rng}, + train=train, + dropout_rate=dropout_rate, + ) return logits, None # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -78,8 +81,9 @@ def loss_fn( """ del label_smoothing per_example_losses = jnp.mean( - jnp.abs(logits_batch - label_batch), - axis=tuple(range(1, logits_batch.ndim))) + jnp.abs(logits_batch - label_batch), + axis=tuple(range(1, logits_batch.ndim)), + ) # mask_batch is assumed to be shape [batch]. if mask_batch is not None: per_example_losses *= mask_batch @@ -88,56 +92,63 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0), - static_broadcasted_argnums=(0,)) - def _eval_model(self, - params: spec.Tensor, - batch: Dict[str, spec.Tensor], - rng: spec.RandomState) -> Dict[str, spec.Tensor]: + jax.pmap, + axis_name='batch', + in_axes=(None, 0, 0, 0), + static_broadcasted_argnums=(0,), + ) + def _eval_model( + self, + params: spec.Tensor, + batch: Dict[str, spec.Tensor], + rng: spec.RandomState, + ) -> Dict[str, spec.Tensor]: """Return the SSIM and loss as a dict.""" logits, _ = self.model_fn( - params, - batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=rng, - update_batch_norm=False) + params, + batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=rng, + update_batch_norm=False, + ) targets = batch['targets'] weights = batch.get('weights') if weights is None: weights = jnp.ones(len(logits)) ssim_vals = ssim( - logits, - targets, - mean=batch['mean'], - std=batch['std'], - volume_max=batch['volume_max']) + logits, + targets, + mean=batch['mean'], + std=batch['std'], + volume_max=batch['volume_max'], + ) ssim_sum = jnp.sum(ssim_vals * weights) summed_loss = self.loss_fn(targets, logits, weights)['summed'] metrics = { - 'ssim': ssim_sum, - 'loss': summed_loss, + 'ssim': ssim_sum, + 'loss': summed_loss, } metrics = jax.lax.psum(metrics, axis_name='batch') return metrics - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def 
_eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del model_state del global_step @@ -146,27 +157,27 @@ def _eval_model_on_split(self, if split not in self._eval_iters: # These iterators repeat indefinitely. self._eval_iters[split] = self._build_input_queue( - data_rng, - split, - data_dir, - global_batch_size=global_batch_size, - repeat_final_dataset=True, - num_batches=num_batches) - - total_metrics = {'ssim': 0., 'loss': 0.} + data_rng, + split, + data_dir, + global_batch_size=global_batch_size, + repeat_final_dataset=True, + num_batches=num_batches, + ) + + total_metrics = {'ssim': 0.0, 'loss': 0.0} eval_rngs = prng.split(model_rng, jax.local_device_count()) for _ in range(num_batches): batch = next(self._eval_iters[split]) # We already sum these metrics across devices inside _eval_model. synced_metrics = self._eval_model(params, batch, eval_rngs) total_metrics = { - k: v + synced_metrics[k][0] for k, v in total_metrics.items() + k: v + synced_metrics[k][0] for k, v in total_metrics.items() } return {k: float(v.item() / num_examples) for k, v in total_metrics.items()} class FastMRIModelSizeWorkload(FastMRIWorkload): - @property def num_pool_layers(self) -> bool: """Whether or not to use tanh activations in the model.""" @@ -187,7 +198,6 @@ def test_target_value(self) -> float: class FastMRITanhWorkload(FastMRIWorkload): - @property def use_tanh(self) -> bool: """Whether or not to use tanh activations in the model.""" @@ -203,7 +213,6 @@ def test_target_value(self) -> float: class FastMRILayerNormWorkload(FastMRIWorkload): - @property def use_layer_norm(self) -> bool: """Whether or not to use tanh activations in the model.""" diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models.py b/algoperf/workloads/fastmri/fastmri_pytorch/models.py index 28f20bf20..16cf8bd54 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/models.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models.py @@ -5,93 +5,93 @@ """ from functools import partial -from typing import Optional import torch -from torch import nn -from torch import Tensor +from torch import Tensor, nn from torch.nn import functional as F from algoperf import init_utils +from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout + +DROPOUT_RATE = 0.0 class UNet(nn.Module): r"""U-Net model from - `"U-net: Convolutional networks - for biomedical image segmentation" - `_. - """ - - def __init__(self, - in_chans: int = 1, - out_chans: int = 1, - num_channels: int = 32, - num_pool_layers: int = 4, - dropout_rate: Optional[float] = 0.0, - use_tanh: bool = False, - use_layer_norm: bool = False) -> None: + `"U-net: Convolutional networks + for biomedical image segmentation" + `_. 
+ """ + + def __init__( + self, + in_chans: int = 1, + out_chans: int = 1, + num_channels: int = 32, + num_pool_layers: int = 4, + use_tanh: bool = False, + use_layer_norm: bool = False, + ) -> None: super().__init__() self.in_chans = in_chans self.out_chans = out_chans self.num_channels = num_channels self.num_pool_layers = num_pool_layers - if dropout_rate is None: - dropout_rate = 0.0 - self.down_sample_layers = nn.ModuleList([ - ConvBlock(in_chans, - num_channels, - dropout_rate, - use_tanh, - use_layer_norm) - ]) + + self.down_sample_layers = nn.ModuleList( + [ConvBlock(in_chans, num_channels, use_tanh, use_layer_norm)] + ) ch = num_channels for _ in range(num_pool_layers - 1): self.down_sample_layers.append( - ConvBlock(ch, ch * 2, dropout_rate, use_tanh, use_layer_norm)) + ConvBlock(ch, ch * 2, use_tanh, use_layer_norm) + ) ch *= 2 - self.conv = ConvBlock(ch, ch * 2, dropout_rate, use_tanh, use_layer_norm) + self.conv = ConvBlock(ch, ch * 2, use_tanh, use_layer_norm) self.up_conv = nn.ModuleList() self.up_transpose_conv = nn.ModuleList() for _ in range(num_pool_layers - 1): self.up_transpose_conv.append( - TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) - self.up_conv.append( - ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm)) + TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm) + ) + self.up_conv.append(ConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) ch //= 2 self.up_transpose_conv.append( - TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) + TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm) + ) self.up_conv.append( - nn.Sequential( - ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm), - nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), - )) + SequentialWithDropout( + ConvBlock(ch * 2, ch, use_tanh, use_layer_norm), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + ) + ) for m in self.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): init_utils.pytorch_default_init(m) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: Tensor, dropout_rate: float = DROPOUT_RATE) -> Tensor: stack = [] output = x # apply down-sampling layers for layer in self.down_sample_layers: - output = layer(output) + output = layer(output, dropout_rate) stack.append(output) output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) - output = self.conv(output) + output = self.conv(output, dropout_rate) # apply up-sampling layers for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): downsample_layer = stack.pop() output = transpose_conv(output) - # reflect pad on the right/botton if needed to handle + # reflect pad on the right/bottom if needed to handle # odd input dimensions padding = [0, 0, 0, 0] if output.shape[-1] != downsample_layer.shape[-1]: @@ -99,10 +99,10 @@ def forward(self, x: Tensor) -> Tensor: if output.shape[-2] != downsample_layer.shape[-2]: padding[3] = 1 # padding bottom if torch.sum(torch.tensor(padding)) != 0: - output = F.pad(output, padding, "reflect") + output = F.pad(output, padding, 'reflect') output = torch.cat([output, downsample_layer], dim=1) - output = conv(output) + output = conv(output, dropout_rate) return output @@ -111,13 +111,11 @@ class ConvBlock(nn.Module): # A Convolutional Block that consists of two convolution layers each # followed by instance normalization, LeakyReLU activation and dropout_rate. 
- def __init__(self, - in_chans: int, - out_chans: int, - dropout_rate: float, - use_tanh: bool, - use_layer_norm: bool) -> None: + def __init__( + self, in_chans: int, out_chans: int, use_tanh: bool, use_layer_norm: bool + ) -> None: super().__init__() + self._supports_custom_dropout = True if use_layer_norm: norm_layer = partial(nn.GroupNorm, 1, eps=1e-6) @@ -127,19 +125,19 @@ def __init__(self, activation_fn = nn.Tanh() else: activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True) - self.conv_layers = nn.Sequential( - nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), - norm_layer(out_chans), - activation_fn, - nn.Dropout2d(dropout_rate), - nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), - norm_layer(out_chans), - activation_fn, - nn.Dropout2d(dropout_rate), + self.conv_layers = SequentialWithDropout( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + norm_layer(out_chans), + activation_fn, + CustomDropout2d(), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + norm_layer(out_chans), + activation_fn, + CustomDropout2d(), ) - def forward(self, x: Tensor) -> Tensor: - return self.conv_layers(x) + def forward(self, x: Tensor, dropout_rate: float) -> Tensor: + return self.conv_layers(x, dropout_rate) class TransposeConvBlock(nn.Module): @@ -147,11 +145,11 @@ class TransposeConvBlock(nn.Module): # layers followed by instance normalization and LeakyReLU activation. def __init__( - self, - in_chans: int, - out_chans: int, - use_tanh: bool, - use_layer_norm: bool, + self, + in_chans: int, + out_chans: int, + use_tanh: bool, + use_layer_norm: bool, ): super().__init__() if use_tanh: @@ -159,10 +157,11 @@ def __init__( else: activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True) self.layers = nn.Sequential( - nn.ConvTranspose2d( - in_chans, out_chans, kernel_size=2, stride=2, bias=False), - nn.InstanceNorm2d(out_chans), - activation_fn, + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + activation_fn, ) def forward(self, x: Tensor) -> Tensor: diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py b/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py index 45b61bea4..7d594b959 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py @@ -32,9 +32,9 @@ def ssim(logits, targets, mean=None, std=None, volume_max=None): # NOTE(dsuo): `volume_max` can be 0 if we have a padded batch, but this will # lead to NaN values in `ssim`. - volume_max = torch.where(volume_max == 0, - torch.ones_like(volume_max), - volume_max) + volume_max = torch.where( + volume_max == 0, torch.ones_like(volume_max), volume_max + ) if mean is None: mean = torch.zeros(logits.shape[0], device=DEVICE) @@ -56,12 +56,9 @@ def ssim(logits, targets, mean=None, std=None, volume_max=None): return ssims -def structural_similarity(im1, - im2, - data_range=1.0, - win_size=7, - k1=0.01, - k2=0.03): +def structural_similarity( + im1, im2, data_range=1.0, win_size=7, k1=0.01, k2=0.03 +): """Compute the mean structural similarity index between two images. NOTE(dsuo): modified from skimage.metrics.structural_similarity. 
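# For reference, the quantities assembled by `structural_similarity` below
# follow the standard SSIM definition (Wang et al., 2004). A minimal sketch of
# the per-window formula, assuming the local means (ux, uy), variances
# (vx, vy) and covariance (vxy) have already been computed with the same
# uniform filter used in that function:
def ssim_map_sketch(ux, uy, vx, vy, vxy, data_range=1.0, k1=0.01, k2=0.03):
  c1 = (k1 * data_range) ** 2
  c2 = (k2 * data_range) ** 2
  numerator = (2 * ux * uy + c1) * (2 * vxy + c2)
  denominator = (ux**2 + uy**2 + c1) * (vx + vy + c2)
  return numerator / denominator  # averaged over windows to get the SSIM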
@@ -92,7 +89,7 @@ def structural_similarity(im1, """ filter_func = functools.partial(_uniform_filter, size=win_size) - num_points = win_size**len(im1.shape) + num_points = win_size ** len(im1.shape) # filter has already normalized by num_points cov_norm = num_points / (num_points - 1) # sample covariance @@ -109,8 +106,8 @@ def structural_similarity(im1, vy = cov_norm * (uyy - uy * uy) vxy = cov_norm * (uxy - ux * uy) - c1 = (k1 * data_range)**2 - c2 = (k2 * data_range)**2 + c1 = (k1 * data_range) ** 2 + c2 = (k2 * data_range) ** 2 a1 = 2 * ux * uy + c1 a2 = 2 * vxy + c2 diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py index 58943de2f..bddf6b1f3 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -9,10 +9,9 @@ import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils -from algoperf import pytorch_utils -from algoperf import spec import algoperf.random_utils as prng +from algoperf import param_utils, pytorch_utils, spec +from algoperf.workloads.fastmri.fastmri_pytorch import models from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet from algoperf.workloads.fastmri.fastmri_pytorch.ssim import ssim from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload @@ -21,28 +20,31 @@ class FastMRIWorkload(BaseFastMRIWorkload): - - def _build_input_queue(self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None): + def _build_input_queue( + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ): per_device_batch_size = int(global_batch_size / N_GPUS) # Only create and iterate over tf input pipeline in one Python process to # avoid creating too many threads. if RANK == 0: data_rng = data_rng.astype('uint32') - np_iter = super()._build_input_queue(data_rng, - split, - data_dir, - global_batch_size, - cache, - repeat_final_dataset, - num_batches) + np_iter = super()._build_input_queue( + data_rng, + split, + data_dir, + global_batch_size, + cache, + repeat_final_dataset, + num_batches, + ) while True: if RANK == 0: @@ -58,20 +60,23 @@ def _build_input_queue(self, else: aux_tensor_list.append(tensor) batch[key] = ( - tensor[0] if USE_PYTORCH_DDP else tensor.view( - -1, *value.shape[2:])) + tensor[0] if USE_PYTORCH_DDP else tensor.view(-1, *value.shape[2:]) + ) # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: if split != 'train': # During eval, the batch size of the remainder might be different. per_device_batch_size = torch.tensor( - len(batch['inputs']), dtype=torch.int32, device=DEVICE) + len(batch['inputs']), dtype=torch.int32, device=DEVICE + ) dist.broadcast(per_device_batch_size, src=0) weights = weights if 'weights' in batch else None if weights is None: - weights = torch.ones((N_GPUS, per_device_batch_size), - dtype=torch.float64, - device=DEVICE) + weights = torch.ones( + (N_GPUS, per_device_batch_size), + dtype=torch.float64, + device=DEVICE, + ) # Has no effect, but without it `batch` has no `weights` key # for RANK == 0, but has one for all others. 
batch['weights'] = weights[0] @@ -82,20 +87,22 @@ def _build_input_queue(self, batch = {} if split != 'train': # During eval, the batch size of the remainder might be different. - per_device_batch_size = torch.empty((), - dtype=torch.int32, - device=DEVICE) + per_device_batch_size = torch.empty( + (), dtype=torch.int32, device=DEVICE + ) dist.broadcast(per_device_batch_size, src=0) - weights = torch.empty((N_GPUS, per_device_batch_size), - dtype=torch.float64, - device=DEVICE) + weights = torch.empty( + (N_GPUS, per_device_batch_size), dtype=torch.float64, device=DEVICE + ) dist.broadcast(weights, src=0) batch['weights'] = weights[RANK] - tensors = torch.empty((2, N_GPUS, per_device_batch_size, 320, 320), - device=DEVICE) + tensors = torch.empty( + (2, N_GPUS, per_device_batch_size, 320, 320), device=DEVICE + ) dist.broadcast(tensors, src=0) - aux_tensors = torch.empty((3, N_GPUS, per_device_batch_size), - device=DEVICE) + aux_tensors = torch.empty( + (3, N_GPUS, per_device_batch_size), device=DEVICE + ) dist.broadcast(aux_tensors, src=0) # Note that the batch dict in the RANK == 0 process is ordered. batch['inputs'] = tensors[0][RANK] @@ -105,19 +112,14 @@ def _build_input_queue(self, batch['volume_max'] = aux_tensors[2][RANK] yield batch - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - del aux_dropout_rate + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = UNet( - num_pool_layers=self.num_pool_layers, - num_channels=self.num_channels, - use_tanh=self.use_tanh, - use_layer_norm=self.use_layer_norm, - dropout_rate=dropout_rate) + num_pool_layers=self.num_pool_layers, + num_channels=self.num_channels, + use_tanh=self.use_tanh, + use_layer_norm=self.use_layer_norm, + ) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -132,13 +134,15 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['up_conv.3.1.weight', 'up_conv.3.1.bias'] def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -152,25 +156,27 @@ def model_fn( model.train() contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): logit_batch = model( - augmented_and_preprocessed_input_batch['inputs'].unsqueeze( - 1)).squeeze(1) + augmented_and_preprocessed_input_batch['inputs'].unsqueeze(1), + dropout_rate=dropout_rate, + ).squeeze(1) return logit_batch, None # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. 
- logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -179,7 +185,8 @@ def loss_fn( """ del label_smoothing per_example_losses = F.l1_loss( - logits_batch, label_batch, reduction='none').mean(dim=(1, 2)) + logits_batch, label_batch, reduction='none' + ).mean(dim=(1, 2)) # mask_batch is assumed to be shape [batch]. if mask_batch is not None: per_example_losses *= mask_batch @@ -188,46 +195,52 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), + 'per_example': per_example_losses, } - def _eval_model(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - rng: spec.RandomState) -> Dict[str, spec.Tensor]: + def _eval_model( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + rng: spec.RandomState, + ) -> Dict[str, spec.Tensor]: """Return the SSIM and loss as a dict.""" outputs, _ = self.model_fn( - params, - batch, - None, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) + params, + batch, + None, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) targets = batch['targets'] weights = batch.get('weights') if weights is None: weights = torch.ones(len(outputs), device=DEVICE) weights_sum = weights.sum().to(torch.int) ssim_sum = ssim( - outputs[:weights_sum], - targets[:weights_sum], - mean=batch['mean'][:weights_sum], - std=batch['std'][:weights_sum], - volume_max=batch['volume_max'][:weights_sum]).sum() + outputs[:weights_sum], + targets[:weights_sum], + mean=batch['mean'][:weights_sum], + std=batch['std'][:weights_sum], + volume_max=batch['volume_max'][:weights_sum], + ).sum() summed_loss = self.loss_fn(targets, outputs, weights)['summed'] return {'ssim': ssim_sum, 'loss': summed_loss} - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del model_state del global_step @@ -236,22 +249,23 @@ def _eval_model_on_split(self, if split not in self._eval_iters: # These iterators repeat indefinitely. 
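# Toy check of the reduction convention documented for `loss_fn` above: the
# mean loss a submission reports is the (masked) summed loss divided by the
# number of valid examples. The tensors below are made up; the last example
# stands in for eval-batch padding and is excluded from the count.
import torch

per_example = torch.tensor([0.4, 0.2, 0.6, 0.8])
mask = torch.tensor([1.0, 1.0, 1.0, 0.0])
per_example = per_example * mask
summed = per_example.sum()    # -> 1.2
n_valid = mask.sum()          # -> 3.0
mean_loss = summed / n_valid  # -> 0.4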
self._eval_iters[split] = self._build_input_queue( - data_rng, - split, - data_dir, - global_batch_size=global_batch_size, - repeat_final_dataset=True, - num_batches=num_batches) + data_rng, + split, + data_dir, + global_batch_size=global_batch_size, + repeat_final_dataset=True, + num_batches=num_batches, + ) total_metrics = { - 'ssim': torch.tensor(0., device=DEVICE), - 'loss': torch.tensor(0., device=DEVICE), + 'ssim': torch.tensor(0.0, device=DEVICE), + 'loss': torch.tensor(0.0, device=DEVICE), } for _ in range(num_batches): batch = next(self._eval_iters[split]) batch_metrics = self._eval_model(params, batch, model_rng) total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() + k: v + batch_metrics[k] for k, v in total_metrics.items() } if USE_PYTORCH_DDP: for metric in total_metrics.values(): @@ -260,7 +274,6 @@ def _eval_model_on_split(self, class FastMRIModelSizeWorkload(FastMRIWorkload): - @property def num_pool_layers(self) -> bool: """Whether or not to use tanh activations in the model.""" @@ -281,7 +294,6 @@ def test_target_value(self) -> float: class FastMRITanhWorkload(FastMRIWorkload): - @property def use_tanh(self) -> bool: """Whether or not to use tanh activations in the model.""" @@ -297,7 +309,6 @@ def test_target_value(self) -> float: class FastMRILayerNormWorkload(FastMRIWorkload): - @property def use_layer_norm(self) -> bool: """Whether or not to use tanh activations in the model.""" diff --git a/algoperf/workloads/fastmri/input_pipeline.py b/algoperf/workloads/fastmri/input_pipeline.py index f20611f43..62b3219c5 100644 --- a/algoperf/workloads/fastmri/input_pipeline.py +++ b/algoperf/workloads/fastmri/input_pipeline.py @@ -16,12 +16,9 @@ _EVAL_SEED = 0 -def _process_example(kspace, - kspace_shape, - target, - target_shape, - volume_max, - seed): +def _process_example( + kspace, kspace_shape, target, target_shape, volume_max, seed +): """Generate a single example (slice from mri image). 
Args: @@ -45,15 +42,17 @@ def _process_example(kspace, acceleration = tf.convert_to_tensor(4.0, dtype=tf.float32) num_low_frequencies = tf.cast( - num_cols_float * center_fraction, dtype=tf.int32) + num_cols_float * center_fraction, dtype=tf.int32 + ) # calculate_center_mask mask = tf.zeros(num_cols, dtype=tf.float32) pad = (num_cols - num_low_frequencies + 1) // 2 mask = tf.tensor_scatter_nd_update( - mask, - tf.reshape(tf.range(pad, pad + num_low_frequencies), (-1, 1)), - tf.ones(num_low_frequencies)) + mask, + tf.reshape(tf.range(pad, pad + num_low_frequencies), (-1, 1)), + tf.ones(num_low_frequencies), + ) # reshape_mask center_mask = tf.reshape(mask, (1, num_cols)) @@ -61,10 +60,12 @@ def _process_example(kspace, # calculate_acceleration_mask num_low_frequencies_float = tf.cast(num_low_frequencies, dtype=tf.float32) prob = (num_cols_float / acceleration - num_low_frequencies_float) / ( - num_cols_float - num_low_frequencies_float) + num_cols_float - num_low_frequencies_float + ) mask = tf.cast( - tf.random.stateless_uniform((num_cols,), seed) < prob, dtype=tf.float32) + tf.random.stateless_uniform((num_cols,), seed) < prob, dtype=tf.float32 + ) acceleration_mask = tf.reshape(mask, (1, num_cols)) mask = tf.math.maximum(center_mask, acceleration_mask) @@ -78,9 +79,11 @@ def _process_example(kspace, shifted_image = tf.signal.ifft2d(shifted_kspace) image = tf.signal.fftshift(shifted_image, axes=(0, 1)) scaling_norm = tf.cast( - tf.math.sqrt( - tf.cast(tf.math.reduce_prod(tf.shape(kspace)[-2:]), 'float32')), - kspace.dtype) + tf.math.sqrt( + tf.cast(tf.math.reduce_prod(tf.shape(kspace)[-2:]), 'float32') + ), + kspace.dtype, + ) image = image * scaling_norm image = tf.stack((tf.math.real(image), tf.math.imag(image)), axis=-1) @@ -108,48 +111,58 @@ def _process_example(kspace, target = tf.clip_by_value(norm_target, -6, 6) return { - 'inputs': image, - 'targets': target, - 'mean': mean, - 'std': std, - 'volume_max': volume_max, + 'inputs': image, + 'targets': target, + 'mean': mean, + 'std': std, + 'volume_max': volume_max, } def _h5_to_examples(path, log=False): """Yield MRI slices from an hdf5 file containing a single MRI volume.""" if log: - tf.print('fastmri_dataset._h5_to_examples call:', - path, - datetime.datetime.now().strftime('%H:%M:%S:%f')) + tf.print( + 'fastmri_dataset._h5_to_examples call:', + path, + datetime.datetime.now().strftime('%H:%M:%S:%f'), + ) with open(path, 'rb') as gf: with h5py.File(gf, 'r') as hf: # NOTE(dsuo): logic taken from reference code volume_max = hf.attrs.get('max', 0.0) for i in range(hf['kspace'].shape[0]): - yield hf['kspace'][i], hf['kspace'][i].shape, hf['reconstruction_esc'][ - i], hf['reconstruction_esc'][i].shape, volume_max + yield ( + hf['kspace'][i], + hf['kspace'][i].shape, + hf['reconstruction_esc'][i], + hf['reconstruction_esc'][i].shape, + volume_max, + ) def _create_generator(filename): signature = ( - tf.TensorSpec(shape=(640, None), dtype=tf.complex64), - tf.TensorSpec(shape=(2,), dtype=tf.int32), - tf.TensorSpec(shape=(320, 320), dtype=tf.float32), - tf.TensorSpec(shape=(2,), dtype=tf.int32), - tf.TensorSpec(shape=(), dtype=tf.float32), + tf.TensorSpec(shape=(640, None), dtype=tf.complex64), + tf.TensorSpec(shape=(2,), dtype=tf.int32), + tf.TensorSpec(shape=(320, 320), dtype=tf.float32), + tf.TensorSpec(shape=(2,), dtype=tf.int32), + tf.TensorSpec(shape=(), dtype=tf.float32), ) return tf.data.Dataset.from_generator( - _h5_to_examples, args=(filename,), output_signature=signature) + _h5_to_examples, args=(filename,), 
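# Plain-NumPy sketch of the k-space undersampling mask built in
# `_process_example` above. Only `acceleration = 4.0` is visible in this hunk;
# `num_cols` and `center_fraction` below are illustrative assumptions, not the
# workload's actual constants.
import numpy as np

num_cols, center_fraction, acceleration = 368, 0.08, 4.0
num_low = int(num_cols * center_fraction)  # fully-sampled low frequencies

# Center mask: a solid band of ones around the middle columns of k-space.
center_mask = np.zeros(num_cols)
pad = (num_cols - num_low + 1) // 2
center_mask[pad:pad + num_low] = 1.0

# Acceleration mask: random columns chosen so the expected overall sampling
# rate is roughly 1 / acceleration.
prob = (num_cols / acceleration - num_low) / (num_cols - num_low)
accel_mask = (np.random.rand(num_cols) < prob).astype(np.float32)

mask = np.maximum(center_mask, accel_mask)  # applied column-wise to k-space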
output_signature=signature + ) -def load_fastmri_split(global_batch_size, - split, - data_dir, - shuffle_rng, - num_batches, - repeat_final_eval_dataset): +def load_fastmri_split( + global_batch_size, + split, + data_dir, + shuffle_rng, + num_batches, + repeat_final_eval_dataset, +): """Creates a split from the FastMRI dataset using tf.data. NOTE: only creates knee singlecoil datasets. @@ -169,11 +182,13 @@ def load_fastmri_split(global_batch_size, # Check if data directories exist because glob will not raise an error if not os.path.exists(os.path.join(data_dir, _TRAIN_DIR)): - raise NotADirectoryError('Directory not found: {}'.format( - os.path.join(data_dir, _TRAIN_DIR))) + raise NotADirectoryError( + 'Directory not found: {}'.format(os.path.join(data_dir, _TRAIN_DIR)) + ) if not os.path.exists(os.path.join(data_dir, _VAL_DIR)): - raise NotADirectoryError('Directory not found: {}'.format( - os.path.join(data_dir, _VAL_DIR))) + raise NotADirectoryError( + 'Directory not found: {}'.format(os.path.join(data_dir, _VAL_DIR)) + ) if split in ['train', 'eval_train']: file_pattern = os.path.join(data_dir, _TRAIN_DIR, '*.h5') @@ -190,10 +205,8 @@ def load_fastmri_split(global_batch_size, shuffle = is_train or split == 'eval_train' ds = tf.data.Dataset.from_tensor_slices(h5_paths) ds = ds.interleave( - _create_generator, - cycle_length=32, - block_length=64, - num_parallel_calls=16) + _create_generator, cycle_length=32, block_length=64, num_parallel_calls=16 + ) if is_train: ds = ds.cache() @@ -201,7 +214,8 @@ def process_example(example_index, example): if shuffle: process_rng = tf.cast(jax.random.fold_in(shuffle_rng, 0), tf.int64) process_rng = tf.random.experimental.stateless_fold_in( - process_rng, example_index) + process_rng, example_index + ) else: # NOTE(dsuo): we use fixed randomness for eval. process_rng = tf.cast(jax.random.PRNGKey(_EVAL_SEED), tf.int64) @@ -211,9 +225,8 @@ def process_example(example_index, example): if shuffle: ds = ds.shuffle( - 16 * global_batch_size, - seed=shuffle_rng[0], - reshuffle_each_iteration=True) + 16 * global_batch_size, seed=shuffle_rng[0], reshuffle_each_iteration=True + ) if is_train: ds = ds.repeat() @@ -231,7 +244,8 @@ def process_example(example_index, example): ds = ds.repeat() ds = ds.prefetch(10) return map( - functools.partial( - data_utils.shard_and_maybe_pad_np, - global_batch_size=global_batch_size), - ds) + functools.partial( + data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size + ), + ds, + ) diff --git a/algoperf/workloads/fastmri/workload.py b/algoperf/workloads/fastmri/workload.py index 051749cc3..0b1ecfaa1 100644 --- a/algoperf/workloads/fastmri/workload.py +++ b/algoperf/workloads/fastmri/workload.py @@ -8,7 +8,6 @@ class BaseFastMRIWorkload(spec.Workload): - @property def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" @@ -61,8 +60,9 @@ def num_eval_train_examples(self) -> int: # Round up from num_validation_examples (which is the default for # num_eval_train_examples) to the next multiple of eval_batch_size, so that # we don't have to extract the correctly sized subset of the training data. - rounded_up_multiple = math.ceil(self.num_validation_examples / - self.eval_batch_size) + rounded_up_multiple = math.ceil( + self.num_validation_examples / self.eval_batch_size + ) return rounded_up_multiple * self.eval_batch_size @property @@ -106,18 +106,22 @@ def step_hint(self) -> int: """Approx. 
steps the baseline can do in the allowed runtime budget.""" return 18_094 - def _build_input_queue(self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None): + def _build_input_queue( + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ): del cache - return input_pipeline.load_fastmri_split(global_batch_size, - split, - data_dir, - data_rng, - num_batches, - repeat_final_dataset) + return input_pipeline.load_fastmri_split( + global_batch_size, + split, + data_dir, + data_rng, + num_batches, + repeat_final_dataset, + ) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py index 3d6939218..53368b384 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -1,5 +1,5 @@ """ -Note: +Note: The following code is adapted from: https://github.com/tensorflow/addons/tree/master/tensorflow_addons/image @@ -12,35 +12,39 @@ import tensorflow as tf _IMAGE_DTYPES = { - tf.dtypes.uint8, - tf.dtypes.int32, - tf.dtypes.int64, - tf.dtypes.float16, - tf.dtypes.float32, - tf.dtypes.float64, + tf.dtypes.uint8, + tf.dtypes.int32, + tf.dtypes.int64, + tf.dtypes.float16, + tf.dtypes.float32, + tf.dtypes.float64, } -Number = Union[float, - int, - np.float16, - np.float32, - np.float64, - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64,] - -TensorLike = Union[List[Union[Number, list]], - tuple, - Number, - np.ndarray, - tf.Tensor, - tf.SparseTensor, - tf.Variable,] +Number = Union[ + float, + int, + np.float16, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, +] + +TensorLike = Union[ + List[Union[Number, list]], + tuple, + Number, + np.ndarray, + tf.Tensor, + tf.SparseTensor, + tf.Variable, +] def get_ndims(image): @@ -50,16 +54,19 @@ def get_ndims(image): def to_4d_image(image): """Convert 2/3/4D image to 4D image. - Args: - image: 2/3/4D `Tensor`. + Args: + image: 2/3/4D `Tensor`. - Returns: - 4D `Tensor` with the same type. - """ - with tf.control_dependencies([ + Returns: + 4D `Tensor` with the same type. + """ + with tf.control_dependencies( + [ tf.debugging.assert_rank_in( - image, [2, 3, 4], message="`image` must be 2/3/4D tensor") - ]): + image, [2, 3, 4], message='`image` must be 2/3/4D tensor' + ) + ] + ): ndims = image.get_shape().ndims if ndims is None: return _dynamic_to_4d_image(image) @@ -80,12 +87,12 @@ def _dynamic_to_4d_image(image): left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) new_shape = tf.concat( - [ - tf.ones(shape=left_pad, dtype=tf.int32), - shape, - tf.ones(shape=right_pad, dtype=tf.int32), - ], - axis=0, + [ + tf.ones(shape=left_pad, dtype=tf.int32), + shape, + tf.ones(shape=right_pad, dtype=tf.int32), + ], + axis=0, ) return tf.reshape(image, new_shape) @@ -93,16 +100,16 @@ def _dynamic_to_4d_image(image): def from_4d_image(image, ndims): """Convert back to an image with `ndims` rank. - Args: - image: 4D `Tensor`. - ndims: The original rank of the image. 
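# Hypothetical round-trip through the rank helpers above: a rank-2 (HW) image
# is lifted to NHWC so the raw image ops can run on it, then restored to its
# original rank. The shapes are illustrative only.
import tensorflow as tf

image_hw = tf.zeros((320, 320))        # rank-2 image
image_4d = to_4d_image(image_hw)       # shape (1, 320, 320, 1)
restored = from_4d_image(image_4d, 2)  # back to shape (320, 320)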
+ Args: + image: 4D `Tensor`. + ndims: The original rank of the image. - Returns: - `ndims`-D `Tensor` with the same type. - """ + Returns: + `ndims`-D `Tensor` with the same type. + """ with tf.control_dependencies( - [tf.debugging.assert_rank(image, 4, - message="`image` must be 4D tensor")]): + [tf.debugging.assert_rank(image, 4, message='`image` must be 4D tensor')] + ): if isinstance(ndims, tf.Tensor): return _dynamic_from_4d_image(image, ndims) elif ndims == 2: @@ -125,63 +132,64 @@ def _dynamic_from_4d_image(image, original_rank): def transform( - images: TensorLike, - transforms: TensorLike, - interpolation: str = "nearest", - fill_mode: str = "constant", - output_shape: Optional[list] = None, - name: Optional[str] = None, - fill_value: TensorLike = 0.0, + images: TensorLike, + transforms: TensorLike, + interpolation: str = 'nearest', + fill_mode: str = 'constant', + output_shape: Optional[list] = None, + name: Optional[str] = None, + fill_value: TensorLike = 0.0, ) -> tf.Tensor: """Applies the given transform(s) to the image(s). - Args: - images: A tensor of shape (num_images, num_rows, num_columns, - num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or - (num_rows, num_columns) (HW). - transforms: Projective transform matrix/matrices. A vector of length 8 or - tensor of size N x 8. If one row of transforms is - [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point - `(x, y)` to a transformed *input* point - `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, - where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to - the transform mapping input points to output points. Note that - gradients are not backpropagated into transformation parameters. - interpolation: Interpolation mode. - Supported values: "nearest", "bilinear". - fill_mode: Points outside the boundaries of the input are filled according - to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). - - *reflect*: `(d c b a | a b c d | d c b a)` - The input is extended by reflecting about the edge of the last pixel. - - *constant*: `(k k k k | a b c d | k k k k)` - The input is extended by filling all values beyond the edge with the - same constant value k = 0. - - *wrap*: `(a b c d | a b c d | a b c d)` - The input is extended by wrapping around to the opposite edge. - - *nearest*: `(a a a a | a b c d | d d d d)` - The input is extended by the nearest pixel. - fill_value: a float represents the value to be filled outside the - boundaries when `fill_mode` is "constant". - output_shape: Output dimesion after the transform, [height, width]. - If None, output is the same size as input image. - - name: The name of the op. - - Returns: - Image(s) with the same type and shape as `images`, with the given - transform(s) applied. Transformed coordinates outside of the input image - will be filled with zeros. - - Raises: - TypeError: If `image` is an invalid type. - ValueError: If output shape is not 1-D int32 Tensor. - """ - with tf.name_scope(name or "transform"): - image_or_images = tf.convert_to_tensor(images, name="images") + Args: + images: A tensor of shape (num_images, num_rows, num_columns, + num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or + (num_rows, num_columns) (HW). + transforms: Projective transform matrix/matrices. A vector of length 8 or + tensor of size N x 8. 
If one row of transforms is + [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point + `(x, y)` to a transformed *input* point + `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, + where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to + the transform mapping input points to output points. Note that + gradients are not backpropagated into transformation parameters. + interpolation: Interpolation mode. + Supported values: "nearest", "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + output_shape: Output dimesion after the transform, [height, width]. + If None, output is the same size as input image. + + name: The name of the op. + + Returns: + Image(s) with the same type and shape as `images`, with the given + transform(s) applied. Transformed coordinates outside of the input image + will be filled with zeros. + + Raises: + TypeError: If `image` is an invalid type. + ValueError: If output shape is not 1-D int32 Tensor. + """ + with tf.name_scope(name or 'transform'): + image_or_images = tf.convert_to_tensor(images, name='images') transform_or_transforms = tf.convert_to_tensor( - transforms, name="transforms", dtype=tf.dtypes.float32) + transforms, name='transforms', dtype=tf.dtypes.float32 + ) if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: - raise TypeError("Invalid dtype %s." % image_or_images.dtype) + raise TypeError('Invalid dtype %s.' 
% image_or_images.dtype) images = to_4d_image(image_or_images) original_ndims = get_ndims(image_or_images) @@ -189,61 +197,67 @@ def transform( output_shape = tf.shape(images)[1:3] output_shape = tf.convert_to_tensor( - output_shape, tf.dtypes.int32, name="output_shape") + output_shape, tf.dtypes.int32, name='output_shape' + ) if not output_shape.get_shape().is_compatible_with([2]): - raise ValueError("output_shape must be a 1-D Tensor of 2 elements: " - "new_height, new_width") + raise ValueError( + 'output_shape must be a 1-D Tensor of 2 elements: new_height, new_width' + ) if len(transform_or_transforms.get_shape()) == 1: transforms = transform_or_transforms[None] elif transform_or_transforms.get_shape().ndims is None: - raise ValueError("transforms rank must be statically known") + raise ValueError('transforms rank must be statically known') elif len(transform_or_transforms.get_shape()) == 2: transforms = transform_or_transforms else: transforms = transform_or_transforms - raise ValueError("transforms should have rank 1 or 2, but got rank %d" % - len(transforms.get_shape())) + raise ValueError( + 'transforms should have rank 1 or 2, but got rank %d' + % len(transforms.get_shape()) + ) fill_value = tf.convert_to_tensor( - fill_value, dtype=tf.float32, name="fill_value") + fill_value, dtype=tf.float32, name='fill_value' + ) output = tf.raw_ops.ImageProjectiveTransformV3( - images=images, - transforms=transforms, - output_shape=output_shape, - interpolation=interpolation.upper(), - fill_mode=fill_mode.upper(), - fill_value=fill_value, + images=images, + transforms=transforms, + output_shape=output_shape, + interpolation=interpolation.upper(), + fill_mode=fill_mode.upper(), + fill_value=fill_value, ) return from_4d_image(output, original_ndims) def angles_to_projective_transforms( - angles: TensorLike, - image_height: TensorLike, - image_width: TensorLike, - name: Optional[str] = None, + angles: TensorLike, + image_height: TensorLike, + image_width: TensorLike, + name: Optional[str] = None, ) -> tf.Tensor: """Returns projective transform(s) for the given angle(s). - Args: - angles: A scalar angle to rotate all images by, or (for batches of - images) a vector with an angle to rotate each image in the batch. The - rank must be statically known (the shape is not `TensorShape(None)`. - image_height: Height of the image(s) to be transformed. - image_width: Width of the image(s) to be transformed. - - Returns: - A tensor of shape (num_images, 8). Projective transforms which can be - given to `transform` op. - """ - with tf.name_scope(name or "angles_to_projective_transforms"): + Args: + angles: A scalar angle to rotate all images by, or (for batches of + images) a vector with an angle to rotate each image in the batch. The + rank must be statically known (the shape is not `TensorShape(None)`. + image_height: Height of the image(s) to be transformed. + image_width: Width of the image(s) to be transformed. + + Returns: + A tensor of shape (num_images, 8). Projective transforms which can be + given to `transform` op. 
+ """ + with tf.name_scope(name or 'angles_to_projective_transforms'): angle_or_angles = tf.convert_to_tensor( - angles, name="angles", dtype=tf.dtypes.float32) + angles, name='angles', dtype=tf.dtypes.float32 + ) if len(angle_or_angles.get_shape()) not in (0, 1): - raise ValueError("angles should have rank 0 or 1.") + raise ValueError('angles should have rank 0 or 1.') if len(angle_or_angles.get_shape()) == 0: angles = angle_or_angles[None] @@ -252,112 +266,116 @@ def angles_to_projective_transforms( cos_angles = tf.math.cos(angles) sin_angles = tf.math.sin(angles) - x_offset = ((image_width - 1) - - (cos_angles * (image_width - 1) - sin_angles * - (image_height - 1))) / 2.0 - y_offset = ((image_height - 1) - - (sin_angles * (image_width - 1) + cos_angles * - (image_height - 1))) / 2.0 + x_offset = ( + (image_width - 1) + - (cos_angles * (image_width - 1) - sin_angles * (image_height - 1)) + ) / 2.0 + y_offset = ( + (image_height - 1) + - (sin_angles * (image_width - 1) + cos_angles * (image_height - 1)) + ) / 2.0 num_angles = tf.shape(angles)[0] return tf.concat( - values=[ - cos_angles[:, None], - -sin_angles[:, None], - x_offset[:, None], - sin_angles[:, None], - cos_angles[:, None], - y_offset[:, None], - tf.zeros((num_angles, 2), tf.dtypes.float32), - ], - axis=1, + values=[ + cos_angles[:, None], + -sin_angles[:, None], + x_offset[:, None], + sin_angles[:, None], + cos_angles[:, None], + y_offset[:, None], + tf.zeros((num_angles, 2), tf.dtypes.float32), + ], + axis=1, ) def rotate_img( - images: TensorLike, - angles: TensorLike, - interpolation: str = "nearest", - fill_mode: str = "constant", - name: Optional[str] = None, - fill_value: TensorLike = 0.0, + images: TensorLike, + angles: TensorLike, + interpolation: str = 'nearest', + fill_mode: str = 'constant', + name: Optional[str] = None, + fill_value: TensorLike = 0.0, ) -> tf.Tensor: """Rotate image(s) counterclockwise by the passed angle(s) in radians. - Args: - images: A tensor of shape - `(num_images, num_rows, num_columns, num_channels)` - (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or - `(num_rows, num_columns)` (HW). - angles: A scalar angle to rotate all images by (if `images` has rank 4) - a vector of length num_images, with an angle for each image in the - batch. - interpolation: Interpolation mode. Supported values: "nearest", - "bilinear". - fill_mode: Points outside the boundaries of the input are filled according - to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). - - *reflect*: `(d c b a | a b c d | d c b a)` - The input is extended by reflecting about the edge of the last pixel. - - *constant*: `(k k k k | a b c d | k k k k)` - The input is extended by filling all values beyond the edge with the - same constant value k = 0. - - *wrap*: `(a b c d | a b c d | a b c d)` - The input is extended by wrapping around to the opposite edge. - - *nearest*: `(a a a a | a b c d | d d d d)` - The input is extended by the nearest pixel. - fill_value: a float represents the value to be filled outside the - boundaries when `fill_mode` is "constant". - name: The name of the op. - - Returns: - Image(s) with the same type and shape as `images`, rotated by the given - angle(s). Empty space due to the rotation will be filled with zeros. - - Raises: - TypeError: If `images` is an invalid type. 
- """ - with tf.name_scope(name or "rotate"): + Args: + images: A tensor of shape + `(num_images, num_rows, num_columns, num_channels)` + (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or + `(num_rows, num_columns)` (HW). + angles: A scalar angle to rotate all images by (if `images` has rank 4) + a vector of length num_images, with an angle for each image in the + batch. + interpolation: Interpolation mode. Supported values: "nearest", + "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + name: The name of the op. + + Returns: + Image(s) with the same type and shape as `images`, rotated by the given + angle(s). Empty space due to the rotation will be filled with zeros. + + Raises: + TypeError: If `images` is an invalid type. + """ + with tf.name_scope(name or 'rotate'): image_or_images = tf.convert_to_tensor(images) if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: - raise TypeError("Invalid dtype %s." % image_or_images.dtype) + raise TypeError('Invalid dtype %s.' % image_or_images.dtype) images = to_4d_image(image_or_images) original_ndims = get_ndims(image_or_images) image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None] image_width = tf.cast(tf.shape(images)[2], tf.dtypes.float32)[None] output = transform( - images, - angles_to_projective_transforms(angles, image_height, image_width), - interpolation=interpolation, - fill_mode=fill_mode, - fill_value=fill_value, + images, + angles_to_projective_transforms(angles, image_height, image_width), + interpolation=interpolation, + fill_mode=fill_mode, + fill_value=fill_value, ) return from_4d_image(output, original_ndims) -def translations_to_projective_transforms(translations: TensorLike, - name: Optional[str] = None - ) -> tf.Tensor: +def translations_to_projective_transforms( + translations: TensorLike, name: Optional[str] = None +) -> tf.Tensor: """Returns projective transform(s) for the given translation(s). - Args: - translations: A 2-element list representing `[dx, dy]` or a matrix of - 2-element lists representing `[dx, dy]` to translate for each image - (for a batch of images). The rank must be statically known - (the shape is not `TensorShape(None)`). - name: The name of the op. - Returns: - A tensor of shape `(num_images, 8)` projective transforms which can be - given to `tfa.image.transform`. - """ - with tf.name_scope(name or "translations_to_projective_transforms"): + Args: + translations: A 2-element list representing `[dx, dy]` or a matrix of + 2-element lists representing `[dx, dy]` to translate for each image + (for a batch of images). The rank must be statically known + (the shape is not `TensorShape(None)`). + name: The name of the op. + Returns: + A tensor of shape `(num_images, 8)` projective transforms which can be + given to `tfa.image.transform`. 
+ """ + with tf.name_scope(name or 'translations_to_projective_transforms'): translation_or_translations = tf.convert_to_tensor( - translations, name="translations", dtype=tf.dtypes.float32) + translations, name='translations', dtype=tf.dtypes.float32 + ) if translation_or_translations.get_shape().ndims is None: raise TypeError( - "translation_or_translations rank must be statically known") + 'translation_or_translations rank must be statically known' + ) if len(translation_or_translations.get_shape()) not in (1, 2): - raise TypeError("Translations should have rank 1 or 2.") + raise TypeError('Translations should have rank 1 or 2.') if len(translation_or_translations.get_shape()) == 1: translations = translation_or_translations[None] @@ -372,67 +390,67 @@ def translations_to_projective_transforms(translations: TensorLike, # where the last entry is implicit. # Translation matrices are always float32. return tf.concat( - values=[ - tf.ones((num_translations, 1), tf.dtypes.float32), - tf.zeros((num_translations, 1), tf.dtypes.float32), - -translations[:, 0, None], - tf.zeros((num_translations, 1), tf.dtypes.float32), - tf.ones((num_translations, 1), tf.dtypes.float32), - -translations[:, 1, None], - tf.zeros((num_translations, 2), tf.dtypes.float32), - ], - axis=1, + values=[ + tf.ones((num_translations, 1), tf.dtypes.float32), + tf.zeros((num_translations, 1), tf.dtypes.float32), + -translations[:, 0, None], + tf.zeros((num_translations, 1), tf.dtypes.float32), + tf.ones((num_translations, 1), tf.dtypes.float32), + -translations[:, 1, None], + tf.zeros((num_translations, 2), tf.dtypes.float32), + ], + axis=1, ) @tf.function def translate( - images: TensorLike, - translations: TensorLike, - interpolation: str = "nearest", - fill_mode: str = "constant", - name: Optional[str] = None, - fill_value: TensorLike = 0.0, + images: TensorLike, + translations: TensorLike, + interpolation: str = 'nearest', + fill_mode: str = 'constant', + name: Optional[str] = None, + fill_value: TensorLike = 0.0, ) -> tf.Tensor: """Translate image(s) by the passed vectors(s). - Args: - images: A tensor of shape - `(num_images, num_rows, num_columns, num_channels)` (NHWC), - `(num_rows, num_columns, num_channels)` (HWC), or - `(num_rows, num_columns)` (HW). The rank must be statically known (the - shape is not `TensorShape(None)`). - translations: A vector representing `[dx, dy]` or (if `images` has rank 4) - a matrix of length num_images, with a `[dx, dy]` vector for each image - in the batch. - interpolation: Interpolation mode. Supported values: "nearest", - "bilinear". - fill_mode: Points outside the boundaries of the input are filled according - to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). - - *reflect*: `(d c b a | a b c d | d c b a)` - The input is extended by reflecting about the edge of the last pixel. - - *constant*: `(k k k k | a b c d | k k k k)` - The input is extended by filling all values beyond the edge with the - same constant value k = 0. - - *wrap*: `(a b c d | a b c d | a b c d)` - The input is extended by wrapping around to the opposite edge. - - *nearest*: `(a a a a | a b c d | d d d d)` - The input is extended by the nearest pixel. - fill_value: a float represents the value to be filled outside the - boundaries when `fill_mode` is "constant". - name: The name of the op. - Returns: - Image(s) with the same type and shape as `images`, translated by the - given vector(s). Empty space due to the translation will be filled with - zeros. 
- Raises: - TypeError: If `images` is an invalid type. - """ - with tf.name_scope(name or "translate"): + Args: + images: A tensor of shape + `(num_images, num_rows, num_columns, num_channels)` (NHWC), + `(num_rows, num_columns, num_channels)` (HWC), or + `(num_rows, num_columns)` (HW). The rank must be statically known (the + shape is not `TensorShape(None)`). + translations: A vector representing `[dx, dy]` or (if `images` has rank 4) + a matrix of length num_images, with a `[dx, dy]` vector for each image + in the batch. + interpolation: Interpolation mode. Supported values: "nearest", + "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + name: The name of the op. + Returns: + Image(s) with the same type and shape as `images`, translated by the + given vector(s). Empty space due to the translation will be filled with + zeros. + Raises: + TypeError: If `images` is an invalid type. + """ + with tf.name_scope(name or 'translate'): return transform( - images, - translations_to_projective_transforms(translations), - interpolation=interpolation, - fill_mode=fill_mode, - fill_value=fill_value, + images, + translations_to_projective_transforms(translations), + interpolation=interpolation, + fill_mode=fill_mode, + fill_value=fill_value, ) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py index 66105335b..f782e50a1 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py @@ -7,29 +7,30 @@ import functools from typing import Dict, Iterator, Tuple -from flax import jax_utils import jax import tensorflow as tf import tensorflow_datasets as tfds +from flax import jax_utils -from algoperf import data_utils -from algoperf import spec +from algoperf import data_utils, spec from algoperf.workloads.imagenet_resnet.imagenet_jax import randaugment TFDS_SPLIT_NAME = { - 'train': 'train', 'eval_train': 'train', 'validation': 'validation' + 'train': 'train', + 'eval_train': 'train', + 'validation': 'validation', } -def _distorted_bounding_box_crop(image_bytes: spec.Tensor, - rng: spec.RandomState, - bbox: spec.Tensor, - min_object_covered: float = 0.1, - aspect_ratio_range: Tuple[float, - float] = (0.75, - 1.33), - area_range: Tuple[float, float] = (0.05, 1.0), - max_attempts: int = 100) -> spec.Tensor: +def _distorted_bounding_box_crop( + image_bytes: spec.Tensor, + rng: spec.RandomState, + bbox: spec.Tensor, + min_object_covered: float = 0.1, + aspect_ratio_range: Tuple[float, float] = (0.75, 1.33), + area_range: Tuple[float, float] = (0.05, 1.0), + max_attempts: int = 100, +) -> spec.Tensor: """Generates cropped_image using one of the bboxes randomly distorted. See `tf.image.sample_distorted_bounding_box` for more documentation. 
@@ -57,14 +58,15 @@ def _distorted_bounding_box_crop(image_bytes: spec.Tensor, """ shape = tf.io.extract_jpeg_shape(image_bytes) bbox_begin, bbox_size, _ = tf.image.stateless_sample_distorted_bounding_box( - shape, - seed=rng, - bounding_boxes=bbox, - min_object_covered=min_object_covered, - aspect_ratio_range=aspect_ratio_range, - area_range=area_range, - max_attempts=max_attempts, - use_image_if_no_bounding_boxes=True) + shape, + seed=rng, + bounding_boxes=bbox, + min_object_covered=min_object_covered, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + max_attempts=max_attempts, + use_image_if_no_bounding_boxes=True, + ) # Crop the image to the specified bounding box. offset_y, offset_x, _ = tf.unstack(bbox_begin) @@ -84,8 +86,9 @@ def resize(image: spec.Tensor, image_size: int) -> spec.Tensor: Returns: Resized image 'Tensor'. """ - return tf.image.resize([image], [image_size, image_size], - method=tf.image.ResizeMethod.BICUBIC)[0] + return tf.image.resize( + [image], [image_size, image_size], method=tf.image.ResizeMethod.BICUBIC + )[0] def _at_least_x_are_equal(a: spec.Tensor, b: spec.Tensor, x: float) -> bool: @@ -95,80 +98,93 @@ def _at_least_x_are_equal(a: spec.Tensor, b: spec.Tensor, x: float) -> bool: return tf.greater_equal(tf.reduce_sum(match), x) -def _decode_and_random_crop(image_bytes: spec.Tensor, - rng: spec.RandomState, - image_size: int, - aspect_ratio_range: Tuple[float, float], - area_range: Tuple[float, float], - resize_size: int) -> spec.Tensor: +def _decode_and_random_crop( + image_bytes: spec.Tensor, + rng: spec.RandomState, + image_size: int, + aspect_ratio_range: Tuple[float, float], + area_range: Tuple[float, float], + resize_size: int, +) -> spec.Tensor: """Make a random crop of image_size.""" bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) image = _distorted_bounding_box_crop( - image_bytes, - rng, - bbox, - min_object_covered=0.1, - aspect_ratio_range=aspect_ratio_range, - area_range=area_range, - max_attempts=10) + image_bytes, + rng, + bbox, + min_object_covered=0.1, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + max_attempts=10, + ) original_shape = tf.io.extract_jpeg_shape(image_bytes) bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3) image = tf.cond( - bad, - lambda: _decode_and_center_crop(image_bytes, image_size, resize_size), - lambda: resize(image, image_size)) + bad, + lambda: _decode_and_center_crop(image_bytes, image_size, resize_size), + lambda: resize(image, image_size), + ) return image -def _decode_and_center_crop(image_bytes: spec.Tensor, - image_size: int, - resize_size: int) -> spec.Tensor: +def _decode_and_center_crop( + image_bytes: spec.Tensor, image_size: int, resize_size: int +) -> spec.Tensor: """Crops to center of image with padding then scales image_size.""" shape = tf.io.extract_jpeg_shape(image_bytes) image_height = shape[0] image_width = shape[1] padded_center_crop_size = tf.cast( - ((image_size / resize_size) * - tf.cast(tf.minimum(image_height, image_width), tf.float32)), - tf.int32) + ( + (image_size / resize_size) + * tf.cast(tf.minimum(image_height, image_width), tf.float32) + ), + tf.int32, + ) offset_height = ((image_height - padded_center_crop_size) + 1) // 2 offset_width = ((image_width - padded_center_crop_size) + 1) // 2 - crop_window = tf.stack([ + crop_window = tf.stack( + [ offset_height, offset_width, padded_center_crop_size, padded_center_crop_size, - ]) + ] + ) image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) 
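# The center crop above keeps a fraction image_size / resize_size of the
# shorter image side and then resizes that crop to image_size. Worked example
# with illustrative values (the workload's actual sizes are configured
# elsewhere and are not part of this hunk):
image_size, resize_size = 224, 256
image_height, image_width = 500, 375
padded_center_crop_size = int(
  (image_size / resize_size) * min(image_height, image_width)
)
# -> int(0.875 * 375) = 328, so a 328x328 center patch is resized to 224x224.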
image = resize(image, image_size) return image -def normalize_image(image: spec.Tensor, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float]) -> spec.Tensor: +def normalize_image( + image: spec.Tensor, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], +) -> spec.Tensor: image -= tf.constant(mean_rgb, shape=[1, 1, 3], dtype=image.dtype) image /= tf.constant(stddev_rgb, shape=[1, 1, 3], dtype=image.dtype) return image -def preprocess_for_train(image_bytes: spec.Tensor, - rng: spec.RandomState, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float], - aspect_ratio_range: Tuple[float, float], - area_range: Tuple[float, float], - image_size: int, - resize_size: int, - dtype: tf.DType = tf.float32, - use_randaug: bool = False, - randaug_num_layers: int = 2, - randaug_magnitude: int = 10) -> spec.Tensor: +def preprocess_for_train( + image_bytes: spec.Tensor, + rng: spec.RandomState, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + aspect_ratio_range: Tuple[float, float], + area_range: Tuple[float, float], + image_size: int, + resize_size: int, + dtype: tf.DType = tf.float32, + use_randaug: bool = False, + randaug_num_layers: int = 2, + randaug_magnitude: int = 10, +) -> spec.Tensor: """Preprocesses the given image for training. Args: @@ -182,33 +198,36 @@ def preprocess_for_train(image_bytes: spec.Tensor, """ rngs = tf.random.experimental.stateless_split(rng, 3) - image = _decode_and_random_crop(image_bytes, - rngs[0], - image_size, - aspect_ratio_range, - area_range, - resize_size) + image = _decode_and_random_crop( + image_bytes, + rngs[0], + image_size, + aspect_ratio_range, + area_range, + resize_size, + ) image = tf.reshape(image, [image_size, image_size, 3]) image = tf.image.stateless_random_flip_left_right(image, seed=rngs[1]) if use_randaug: image = tf.cast(tf.clip_by_value(image, 0, 255), tf.uint8) - image = randaugment.distort_image_with_randaugment(image, - randaug_num_layers, - randaug_magnitude, - rngs[2]) + image = randaugment.distort_image_with_randaugment( + image, randaug_num_layers, randaug_magnitude, rngs[2] + ) image = tf.cast(image, tf.float32) image = normalize_image(image, mean_rgb, stddev_rgb) image = tf.image.convert_image_dtype(image, dtype=dtype) return image -def preprocess_for_eval(image_bytes: spec.Tensor, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float], - image_size: int, - resize_size: int, - dtype: tf.DType = tf.float32) -> spec.Tensor: +def preprocess_for_eval( + image_bytes: spec.Tensor, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + image_size: int, + resize_size: int, + dtype: tf.DType = tf.float32, +) -> spec.Tensor: """Preprocesses the given image for evaluation. Args: @@ -229,10 +248,12 @@ def preprocess_for_eval(image_bytes: spec.Tensor, # Modified from # github.com/google/init2winit/blob/master/init2winit/dataset_lib/ (cont. below) # image_preprocessing.py. -def mixup_tf(key: spec.RandomState, - inputs: spec.Tensor, - targets: spec.Tensor, - alpha: float = 0.2) -> Tuple[spec.Tensor, spec.Tensor]: +def mixup_tf( + key: spec.RandomState, + inputs: spec.Tensor, + targets: spec.Tensor, + alpha: float = 0.2, +) -> Tuple[spec.Tensor, spec.Tensor]: """Perform mixup https://arxiv.org/abs/1710.09412. 
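# Generic mixup recipe from the paper referenced by `mixup_tf` above. The
# exact batch pairing used by the workload is outside this hunk, so the NumPy
# sketch below is only the standard formulation, not the implementation: draw
# lam ~ Beta(alpha, alpha) and blend each example (inputs and one-hot targets)
# with another example from the same batch using the same weight.
import numpy as np


def mixup_numpy_sketch(inputs, targets, alpha=0.2, seed=0):
  rng = np.random.default_rng(seed)
  lam = rng.beta(alpha, alpha)
  perm = rng.permutation(len(inputs))
  mixed_inputs = lam * inputs + (1.0 - lam) * inputs[perm]
  mixed_targets = lam * targets + (1.0 - lam) * targets[perm]
  return mixed_inputs, mixed_targets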
NOTE: Code taken from https://github.com/google/big_vision with variables @@ -261,24 +282,26 @@ def mixup_tf(key: spec.RandomState, return inputs, targets -def create_split(split, - dataset_builder, - rng, - global_batch_size, - train, - image_size, - resize_size, - mean_rgb, - stddev_rgb, - cache=False, - repeat_final_dataset=False, - aspect_ratio_range=(0.75, 4.0 / 3.0), - area_range=(0.08, 1.0), - use_mixup=False, - mixup_alpha=0.1, - use_randaug=False, - randaug_num_layers=2, - randaug_magnitude=10) -> Iterator[Dict[str, spec.Tensor]]: +def create_split( + split, + dataset_builder, + rng, + global_batch_size, + train, + image_size, + resize_size, + mean_rgb, + stddev_rgb, + cache=False, + repeat_final_dataset=False, + aspect_ratio_range=(0.75, 4.0 / 3.0), + area_range=(0.08, 1.0), + use_mixup=False, + mixup_alpha=0.1, + use_randaug=False, + randaug_num_layers=2, + randaug_magnitude=10, +) -> Iterator[Dict[str, spec.Tensor]]: """Creates a split from the ImageNet dataset using TensorFlow Datasets.""" shuffle_rng, preprocess_rng, mixup_rng = jax.random.split(rng, 3) @@ -286,34 +309,35 @@ def decode_example(example_index, example): dtype = tf.float32 if train: per_step_preprocess_rng = tf.random.experimental.stateless_fold_in( - tf.cast(preprocess_rng, tf.int64), example_index) - - image = preprocess_for_train(example['image'], - per_step_preprocess_rng, - mean_rgb, - stddev_rgb, - aspect_ratio_range, - area_range, - image_size, - resize_size, - dtype, - use_randaug, - randaug_num_layers, - randaug_magnitude) + tf.cast(preprocess_rng, tf.int64), example_index + ) + + image = preprocess_for_train( + example['image'], + per_step_preprocess_rng, + mean_rgb, + stddev_rgb, + aspect_ratio_range, + area_range, + image_size, + resize_size, + dtype, + use_randaug, + randaug_num_layers, + randaug_magnitude, + ) else: - image = preprocess_for_eval(example['image'], - mean_rgb, - stddev_rgb, - image_size, - resize_size, - dtype) + image = preprocess_for_eval( + example['image'], mean_rgb, stddev_rgb, image_size, resize_size, dtype + ) return {'inputs': image, 'targets': example['label']} ds = dataset_builder.as_dataset( - split=TFDS_SPLIT_NAME[split], - decoders={ - 'image': tfds.decode.SkipDecoding(), - }) + split=TFDS_SPLIT_NAME[split], + decoders={ + 'image': tfds.decode.SkipDecoding(), + }, + ) options = tf.data.Options() options.threading.private_threadpool_size = 48 ds = ds.with_options(options) @@ -336,18 +360,21 @@ def decode_example(example_index, example): def mixup_batch(batch_index, batch): per_batch_mixup_rng = tf.random.experimental.stateless_fold_in( - mixup_rng, batch_index) + mixup_rng, batch_index + ) (inputs, targets) = mixup_tf( - per_batch_mixup_rng, - batch['inputs'], - batch['targets'], - alpha=mixup_alpha) + per_batch_mixup_rng, + batch['inputs'], + batch['targets'], + alpha=mixup_alpha, + ) batch['inputs'] = inputs batch['targets'] = targets return batch ds = ds.enumerate().map( - mixup_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE) + mixup_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE + ) else: raise ValueError('Mixup can only be used for the training split.') @@ -359,44 +386,48 @@ def mixup_batch(batch_index, batch): return ds -def create_input_iter(split: str, - dataset_builder: tfds.core.dataset_builder.DatasetBuilder, - rng: spec.RandomState, - global_batch_size: int, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float], - image_size: int, - resize_size: int, - aspect_ratio_range: Tuple[float, float], - area_range: 
Tuple[float, float], - train: bool, - cache: bool, - repeat_final_dataset: bool, - use_mixup: bool, - mixup_alpha: float, - use_randaug: bool) -> Iterator[Dict[str, spec.Tensor]]: +def create_input_iter( + split: str, + dataset_builder: tfds.core.dataset_builder.DatasetBuilder, + rng: spec.RandomState, + global_batch_size: int, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + image_size: int, + resize_size: int, + aspect_ratio_range: Tuple[float, float], + area_range: Tuple[float, float], + train: bool, + cache: bool, + repeat_final_dataset: bool, + use_mixup: bool, + mixup_alpha: float, + use_randaug: bool, +) -> Iterator[Dict[str, spec.Tensor]]: ds = create_split( - split, - dataset_builder, - rng, - global_batch_size, - train=train, - image_size=image_size, - resize_size=resize_size, - mean_rgb=mean_rgb, - stddev_rgb=stddev_rgb, - cache=cache, - repeat_final_dataset=repeat_final_dataset, - aspect_ratio_range=aspect_ratio_range, - area_range=area_range, - use_mixup=use_mixup, - mixup_alpha=mixup_alpha, - use_randaug=use_randaug) + split, + dataset_builder, + rng, + global_batch_size, + train=train, + image_size=image_size, + resize_size=resize_size, + mean_rgb=mean_rgb, + stddev_rgb=stddev_rgb, + cache=cache, + repeat_final_dataset=repeat_final_dataset, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + use_mixup=use_mixup, + mixup_alpha=mixup_alpha, + use_randaug=use_randaug, + ) it = map( - functools.partial( - data_utils.shard_and_maybe_pad_np, - global_batch_size=global_batch_size), - ds) + functools.partial( + data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size + ), + ds, + ) # Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%. it = jax_utils.prefetch_to_device(it, 2) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py index ffa60b260..84ad4fe21 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py @@ -7,8 +7,8 @@ import functools from typing import Any, Callable, Optional, Tuple -from flax import linen as nn import jax.numpy as jnp +from flax import linen as nn from algoperf import spec @@ -17,12 +17,13 @@ class ResNetBlock(nn.Module): """ResNet block.""" + filters: int conv: ModuleDef norm: ModuleDef act: Callable strides: Tuple[int, int] = (1, 1) - bn_init_scale: float = 0. 
+ bn_init_scale: float = 0.0 @nn.compact def __call__(self, x: spec.Tensor) -> spec.Tensor: @@ -35,8 +36,8 @@ def __call__(self, x: spec.Tensor) -> spec.Tensor: if residual.shape != y.shape or self.strides != (1, 1): residual = self.conv( - self.filters, (1, 1), self.strides, name='Conv_proj')( - residual) + self.filters, (1, 1), self.strides, name='Conv_proj' + )(residual) residual = self.norm(name='BatchNorm_proj')(residual) return self.act(residual + y) @@ -44,6 +45,7 @@ def __call__(self, x: spec.Tensor) -> spec.Tensor: class BottleneckResNetBlock(nn.Module): """Bottleneck ResNet block.""" + filters: int conv: ModuleDef norm: ModuleDef @@ -65,8 +67,8 @@ def __call__(self, x: spec.Tensor) -> spec.Tensor: if residual.shape != y.shape or self.strides != (1, 1): residual = self.conv( - self.filters * 4, (1, 1), self.strides, name='Conv_proj')( - residual) + self.filters * 4, (1, 1), self.strides, name='Conv_proj' + )(residual) residual = self.norm(name='BatchNorm_proj')(residual) return self.act(residual + y) @@ -79,30 +81,35 @@ class ResNet(nn.Module): num_filters: int = 64 dtype: Any = jnp.float32 act: Callable = nn.relu - bn_init_scale: float = 0. + bn_init_scale: float = 0.0 @nn.compact - def __call__(self, - x: spec.Tensor, - update_batch_norm: bool = True, - use_running_average_bn: Optional[bool] = None) -> spec.Tensor: + def __call__( + self, + x: spec.Tensor, + update_batch_norm: bool = True, + use_running_average_bn: Optional[bool] = None, + ) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) # Preserve default behavior for backwards compatibility if use_running_average_bn is None: use_running_average_bn = not update_batch_norm norm = functools.partial( - nn.BatchNorm, - use_running_average=use_running_average_bn, - momentum=0.9, - epsilon=1e-5, - dtype=self.dtype) + nn.BatchNorm, + use_running_average=use_running_average_bn, + momentum=0.9, + epsilon=1e-5, + dtype=self.dtype, + ) x = conv( - self.num_filters, (7, 7), (2, 2), - padding=[(3, 3), (3, 3)], - name='Conv_init')( - x) + self.num_filters, + (7, 7), + (2, 2), + padding=[(3, 3), (3, 3)], + name='Conv_init', + )(x) x = norm(name='BatchNorm_init')(x) x = self.act(x) x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=((1, 1), (1, 1))) @@ -110,23 +117,23 @@ def __call__(self, for j in range(block_size): strides = (2, 2) if i > 0 and j == 0 else (1, 1) x = self.block_cls( - self.num_filters * 2**i, - strides=strides, - conv=conv, - norm=norm, - act=self.act, - bn_init_scale=self.bn_init_scale)( - x) + self.num_filters * 2**i, + strides=strides, + conv=conv, + norm=norm, + act=self.act, + bn_init_scale=self.bn_init_scale, + )(x) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense( - self.num_classes, - kernel_init=nn.initializers.normal(), - dtype=self.dtype)( - x) + self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype + )(x) return x ResNet18 = functools.partial( - ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock) + ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock +) ResNet50 = functools.partial( - ResNet, stage_sizes=(3, 4, 6, 3), block_cls=BottleneckResNetBlock) + ResNet, stage_sizes=(3, 4, 6, 3), block_cls=BottleneckResNetBlock +) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py index c68e2de33..87e218a0c 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -9,16 +9,15 @@ import 
tensorflow as tf -from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ - rotate_img -from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ - transform -from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ - translate +from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import ( + rotate_img, + transform, + translate, +) # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. -_MAX_LEVEL = 10. +_MAX_LEVEL = 10.0 def blend(image1, image2, factor): @@ -86,10 +85,12 @@ def cutout(image, pad_size, replace=0): # Sample the center location in the image where the zero mask will be applied. cutout_center_height = tf.random.uniform( - shape=[], minval=0, maxval=image_height, dtype=tf.int32) + shape=[], minval=0, maxval=image_height, dtype=tf.int32 + ) cutout_center_width = tf.random.uniform( - shape=[], minval=0, maxval=image_width, dtype=tf.int32) + shape=[], minval=0, maxval=image_width, dtype=tf.int32 + ) lower_pad = tf.maximum(0, cutout_center_height - pad_size) upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size) @@ -97,20 +98,18 @@ def cutout(image, pad_size, replace=0): right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size) cutout_shape = [ - image_height - (lower_pad + upper_pad), - image_width - (left_pad + right_pad), + image_height - (lower_pad + upper_pad), + image_width - (left_pad + right_pad), ] padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] mask = tf.pad( - tf.zeros(cutout_shape, dtype=image.dtype), - padding_dims, - constant_values=1) + tf.zeros(cutout_shape, dtype=image.dtype), padding_dims, constant_values=1 + ) mask = tf.expand_dims(mask, -1) mask = tf.tile(mask, [1, 1, 3]) image = tf.where( - tf.equal(mask, 0), - tf.ones_like(image, dtype=image.dtype) * replace, - image) + tf.equal(mask, 0), tf.ones_like(image, dtype=image.dtype) * replace, image + ) return image @@ -204,7 +203,7 @@ def shear_x(image, level, replace): # with a matrix form of: # [1 level # 0 1]. - image = transform(wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) + image = transform(wrap(image), [1.0, level, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]) return unwrap(image, replace) @@ -214,7 +213,7 @@ def shear_y(image, level, replace): # with a matrix form of: # [1 0 # level 1]. - image = transform(wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) + image = transform(wrap(image), [1.0, 0.0, 0.0, level, 1.0, 0.0, 0.0, 0.0]) return unwrap(image, replace) @@ -264,9 +263,12 @@ def sharpness(image, factor): # Make image 4D for conv operation. image = tf.expand_dims(image, 0) # SMOOTH PIL Kernel. - kernel = tf.constant([[1, 1, 1], [1, 5, 1], [1, 1, 1]], - dtype=tf.float32, - shape=[3, 3, 1, 1]) / 13. + kernel = ( + tf.constant( + [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, shape=[3, 3, 1, 1] + ) + / 13.0 + ) # Tile across channel dimension. kernel = tf.tile(kernel, [1, 1, 3, 1]) strides = [1, 1, 1, 1] @@ -274,7 +276,8 @@ def sharpness(image, factor): # Some augmentation that uses depth-wise conv will cause crashing when # training on GPU. degenerate = tf.nn.depthwise_conv2d( - image, kernel, strides, padding='VALID', dilations=[1, 1]) + image, kernel, strides, padding='VALID', dilations=[1, 1] + ) degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0]) @@ -316,9 +319,10 @@ def build_lut(histo, step): # If step is zero, return the original image. 
Otherwise, build # lut from the full histogram and step and then index from it. result = tf.cond( - tf.equal(step, 0), - lambda: im, - lambda: tf.gather(build_lut(histo, step), im)) + tf.equal(step, 0), + lambda: im, + lambda: tf.gather(build_lut(histo, step), im), + ) return tf.cast(result, tf.uint8) @@ -373,9 +377,10 @@ def unwrap(image, replace): # Where they are zero, fill them in with 'replace'. flattened_image = tf.where( - tf.equal(alpha_channel, 0), - tf.ones_like(flattened_image, dtype=image.dtype) * replace, - flattened_image) + tf.equal(alpha_channel, 0), + tf.ones_like(flattened_image, dtype=image.dtype) * replace, + flattened_image, + ) image = tf.reshape(flattened_image, image_shape) image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3]) @@ -383,22 +388,22 @@ def unwrap(image, replace): NAME_TO_FUNC = { - 'AutoContrast': autocontrast, - 'Equalize': equalize, - 'Invert': invert, - 'Rotate': rotate, - 'Posterize': posterize, - 'Solarize': solarize, - 'SolarizeAdd': solarize_add, - 'Color': color, - 'Contrast': contrast, - 'Brightness': brightness, - 'Sharpness': sharpness, - 'ShearX': shear_x, - 'ShearY': shear_y, - 'TranslateX': translate_x, - 'TranslateY': translate_y, - 'Cutout': cutout, + 'AutoContrast': autocontrast, + 'Equalize': equalize, + 'Invert': invert, + 'Rotate': rotate, + 'Posterize': posterize, + 'Solarize': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'Contrast': contrast, + 'Brightness': brightness, + 'Sharpness': sharpness, + 'ShearX': shear_x, + 'ShearY': shear_y, + 'TranslateX': translate_x, + 'TranslateY': translate_y, + 'Cutout': cutout, } @@ -410,7 +415,7 @@ def _randomly_negate_tensor(tensor): def _rotate_level_to_arg(level): - level = (level / _MAX_LEVEL) * 30. + level = (level / _MAX_LEVEL) * 30.0 level = _randomly_negate_tensor(level) return (level,) @@ -435,47 +440,28 @@ def _translate_level_to_arg(level, translate_const): def level_to_arg(cutout_const, translate_const): return { - 'AutoContrast': - lambda level: (), - 'Equalize': - lambda level: (), - 'Invert': - lambda level: (), - 'Rotate': - _rotate_level_to_arg, - 'Posterize': - lambda level: (int((level / _MAX_LEVEL) * 4),), - 'Solarize': - lambda level: (int((level / _MAX_LEVEL) * 256),), - 'SolarizeAdd': - lambda level: (int((level / _MAX_LEVEL) * 110),), - 'Color': - _enhance_level_to_arg, - 'Contrast': - _enhance_level_to_arg, - 'Brightness': - _enhance_level_to_arg, - 'Sharpness': - _enhance_level_to_arg, - 'ShearX': - _shear_level_to_arg, - 'ShearY': - _shear_level_to_arg, - 'Cutout': - lambda level: (int((level / _MAX_LEVEL) * cutout_const),), - 'TranslateX': - lambda level: _translate_level_to_arg(level, translate_const), - 'TranslateY': - lambda level: _translate_level_to_arg(level, translate_const), + 'AutoContrast': lambda level: (), + 'Equalize': lambda level: (), + 'Invert': lambda level: (), + 'Rotate': _rotate_level_to_arg, + 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4),), + 'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), + 'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), + 'Color': _enhance_level_to_arg, + 'Contrast': _enhance_level_to_arg, + 'Brightness': _enhance_level_to_arg, + 'Sharpness': _enhance_level_to_arg, + 'ShearX': _shear_level_to_arg, + 'ShearY': _shear_level_to_arg, + 'Cutout': lambda level: (int((level / _MAX_LEVEL) * cutout_const),), + 'TranslateX': lambda level: _translate_level_to_arg(level, translate_const), + 'TranslateY': lambda level: _translate_level_to_arg(level, 
translate_const), } -def _parse_policy_info(name, - prob, - level, - replace_value, - cutout_const, - translate_const): +def _parse_policy_info( + name, prob, level, replace_value, cutout_const, translate_const +): """Return the function that corresponds to `name` and update `level` param.""" func = NAME_TO_FUNC[name] args = level_to_arg(cutout_const, translate_const)[name](level) @@ -514,45 +500,49 @@ def distort_image_with_randaugment(image, num_layers, magnitude, key): """ replace_value = [128] * 3 available_ops = [ - 'AutoContrast', - 'Equalize', - 'Invert', - 'Rotate', - 'Posterize', - 'Solarize', - 'Color', - 'Contrast', - 'Brightness', - 'Sharpness', - 'ShearX', - 'ShearY', - 'TranslateX', - 'TranslateY', - 'Cutout', - 'SolarizeAdd', + 'AutoContrast', + 'Equalize', + 'Invert', + 'Rotate', + 'Posterize', + 'Solarize', + 'Color', + 'Contrast', + 'Brightness', + 'Sharpness', + 'ShearX', + 'ShearY', + 'TranslateX', + 'TranslateY', + 'Cutout', + 'SolarizeAdd', ] for layer_num in range(num_layers): key = tf.random.experimental.stateless_fold_in(key, layer_num) - op_to_select = tf.random.stateless_uniform([], - seed=key, - maxval=len(available_ops), - dtype=tf.int32) + op_to_select = tf.random.stateless_uniform( + [], seed=key, maxval=len(available_ops), dtype=tf.int32 + ) random_magnitude = float(magnitude) with tf.name_scope('randaug_layer_{}'.format(layer_num)): - for (i, op_name) in enumerate(available_ops): + for i, op_name in enumerate(available_ops): key = tf.random.experimental.stateless_fold_in(key, i) - prob = tf.random.stateless_uniform([], - seed=key, - minval=0.2, - maxval=0.8, - dtype=tf.float32) - func, _, args = _parse_policy_info(op_name, prob, random_magnitude, - replace_value, cutout_const=40, - translate_const=100) + prob = tf.random.stateless_uniform( + [], seed=key, minval=0.2, maxval=0.8, dtype=tf.float32 + ) + func, _, args = _parse_policy_info( + op_name, + prob, + random_magnitude, + replace_value, + cutout_const=40, + translate_const=100, + ) image = tf.cond( - tf.equal(i, op_to_select), - lambda selected_func=func, - selected_args=args: selected_func(image, *selected_args), - lambda: image) + tf.equal(i, op_to_select), + lambda selected_func=func, selected_args=args: selected_func( + image, *selected_args + ), + lambda: image, + ) return image diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index 4ec3937b8..c3035c212 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -9,70 +9,75 @@ import math from typing import Dict, Iterator, Optional, Tuple -from flax import jax_utils -from flax import linen as nn -from flax.core import pop import jax -from jax import lax import jax.numpy as jnp import optax import tensorflow_datasets as tfds +from flax import jax_utils +from flax import linen as nn +from flax.core import pop +from jax import lax -from algoperf import param_utils +from algoperf import param_utils, spec from algoperf import random_utils as prng -from algoperf import spec from algoperf.workloads.imagenet_resnet import imagenet_v2 -from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline -from algoperf.workloads.imagenet_resnet.imagenet_jax import models -from algoperf.workloads.imagenet_resnet.workload import \ - BaseImagenetResNetWorkload +from algoperf.workloads.imagenet_resnet.imagenet_jax import ( + input_pipeline, + models, +) +from 
algoperf.workloads.imagenet_resnet.workload import ( + BaseImagenetResNetWorkload, +) class ImagenetResNetWorkload(BaseImagenetResNetWorkload): - def _build_dataset( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - use_mixup: bool = False, - use_randaug: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + use_mixup: bool = False, + use_randaug: bool = False, + ) -> Iterator[Dict[str, spec.Tensor]]: if split == 'test': np_iter = imagenet_v2.get_imagenet_v2_iter( - data_dir, - global_batch_size, - mean_rgb=self.train_mean, - stddev_rgb=self.train_stddev, - image_size=self.center_crop_size, - resize_size=self.resize_size) + data_dir, + global_batch_size, + mean_rgb=self.train_mean, + stddev_rgb=self.train_stddev, + image_size=self.center_crop_size, + resize_size=self.resize_size, + ) return itertools.cycle(np_iter) ds_builder = tfds.builder('imagenet2012:5.1.0', data_dir=data_dir) train = split == 'train' ds = input_pipeline.create_input_iter( - split, - ds_builder, - data_rng, - global_batch_size, - self.train_mean, - self.train_stddev, - self.center_crop_size, - self.resize_size, - self.aspect_ratio_range, - self.scale_ratio_range, - train=train, - cache=not train if cache is None else cache, - repeat_final_dataset=repeat_final_dataset, - use_mixup=use_mixup, - mixup_alpha=0.2, - use_randaug=use_randaug) + split, + ds_builder, + data_rng, + global_batch_size, + self.train_mean, + self.train_stddev, + self.center_crop_size, + self.resize_size, + self.aspect_ratio_range, + self.scale_ratio_range, + train=train, + cache=not train if cache is None else cache, + repeat_final_dataset=repeat_final_dataset, + use_mixup=use_mixup, + mixup_alpha=0.2, + use_randaug=use_randaug, + ) return ds def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: + self, model_state: spec.ModelAuxiliaryState + ) -> spec.ModelAuxiliaryState: """Sync the batch statistics across replicas.""" # An axis_name is passed to pmap which can then be used by pmean. 
# In this case each device has its own version of the batch statistics and @@ -83,13 +88,9 @@ def sync_batch_stats( return new_model_state def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Dropout is unused.""" - del dropout_rate - del aux_dropout_rate + self, + rng: spec.RandomState, + ) -> spec.ModelInitState: model_cls = getattr(models, 'ResNet50') if self.use_silu and self.use_gelu: @@ -102,15 +103,17 @@ def init_model_fn( act_fnc = nn.relu model = model_cls( - num_classes=self._num_classes, - act=act_fnc, - bn_init_scale=self.bn_init_scale, - dtype=jnp.float32) + num_classes=self._num_classes, + act=act_fnc, + bn_init_scale=self.bn_init_scale, + dtype=jnp.float32, + ) self._model = model input_shape = (1, 224, 224, 3) - variables = jax.jit(model.init)({'params': rng}, - jnp.ones(input_shape, model.dtype)) - model_state, params = pop(variables, "params") + variables = jax.jit(model.init)( + {'params': rng}, jnp.ones(input_shape, model.dtype) + ) + model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) model_state = jax_utils.replicate(model_state) @@ -121,63 +124,70 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, 0), - static_broadcasted_argnums=(0,)) - def _eval_model(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: + jax.pmap, + axis_name='batch', + in_axes=(None, 0, 0, 0, 0), + static_broadcasted_argnums=(0,), + ) + def _eval_model( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Dict[str, spec.Tensor]: logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng=rng, - update_batch_norm=False) + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng=rng, + update_batch_norm=False, + ) weights = batch.get('weights') return self._compute_metrics(logits, batch['targets'], weights) def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} if update_batch_norm: logits, new_model_state = self._model.apply( - variables, - augmented_and_preprocessed_input_batch['inputs'], - update_batch_norm=update_batch_norm, - mutable=['batch_stats'], - use_running_average_bn=use_running_average_bn) + variables, + augmented_and_preprocessed_input_batch['inputs'], + update_batch_norm=update_batch_norm, + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn, + ) return logits, new_model_state else: logits = self._model.apply( - variables, - 
augmented_and_preprocessed_input_batch['inputs'], - update_batch_norm=update_batch_norm, - mutable=False, - use_running_average_bn=use_running_average_bn) + variables, + augmented_and_preprocessed_input_batch['inputs'], + update_batch_norm=update_batch_norm, + mutable=False, + use_running_average_bn=use_running_average_bn, + ) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -186,12 +196,14 @@ def loss_fn( """ if label_batch.shape[-1] != self._num_classes: one_hot_labels = jax.nn.one_hot( - label_batch, num_classes=self._num_classes) + label_batch, num_classes=self._num_classes + ) else: one_hot_labels = label_batch smoothed_labels = optax.smooth_labels(one_hot_labels, label_smoothing) per_example_losses = -jnp.sum( - smoothed_labels * jax.nn.log_softmax(logits_batch, axis=-1), axis=-1) + smoothed_labels * jax.nn.log_softmax(logits_batch, axis=-1), axis=-1 + ) # mask_batch is assumed to be shape [batch]. if mask_batch is not None: per_example_losses *= mask_batch @@ -200,36 +212,37 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } - def _compute_metrics(self, - logits: spec.Tensor, - labels: spec.Tensor, - weights: spec.Tensor) -> Dict[str, spec.Tensor]: + def _compute_metrics( + self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor + ) -> Dict[str, spec.Tensor]: if weights is None: weights = jnp.ones(len(logits)) summed_loss = self.loss_fn(labels, logits, weights)['summed'] # not accuracy, but nr. of correct predictions accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights) metrics = { - 'loss': summed_loss, - 'accuracy': accuracy, + 'loss': summed_loss, + 'accuracy': accuracy, } metrics = lax.psum(metrics, axis_name='batch') return metrics - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: del global_step if model_state is not None: # Sync batch statistics across replicas before evaluating. @@ -239,13 +252,14 @@ def _eval_model_on_split(self, # We already repeat the dataset indefinitely in tf.data. 
if split not in self._eval_iters: self._eval_iters[split] = self._build_input_queue( - data_rng, - split=split, - global_batch_size=global_batch_size, - data_dir=data_dir, - cache=True, - repeat_final_dataset=True, - num_batches=num_batches) + data_rng, + split=split, + global_batch_size=global_batch_size, + data_dir=data_dir, + cache=True, + repeat_final_dataset=True, + num_batches=num_batches, + ) eval_metrics = {} for bi in range(num_batches): @@ -253,22 +267,21 @@ def _eval_model_on_split(self, step_eval_rngs = prng.split(eval_rng, jax.local_device_count()) batch = next(self._eval_iters[split]) # We already average these metrics across devices inside _compute_metrics. - synced_metrics = self._eval_model(params, - batch, - model_state, - step_eval_rngs) + synced_metrics = self._eval_model( + params, batch, model_state, step_eval_rngs + ) for metric_name, metric_value in synced_metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0 eval_metrics[metric_name] += metric_value - eval_metrics = jax.tree.map(lambda x: float(x[0] / num_examples), - eval_metrics) + eval_metrics = jax.tree.map( + lambda x: float(x[0] / num_examples), eval_metrics + ) return eval_metrics class ImagenetResNetSiLUWorkload(ImagenetResNetWorkload): - @property def use_silu(self) -> bool: return True @@ -283,7 +296,6 @@ def test_target_value(self) -> float: class ImagenetResNetGELUWorkload(ImagenetResNetWorkload): - @property def use_gelu(self) -> bool: return True @@ -298,7 +310,6 @@ def test_target_value(self) -> float: class ImagenetResNetLargeBNScaleWorkload(ImagenetResNetWorkload): - @property def bn_init_scale(self) -> float: return 8.0 diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py index aba9e671f..c980faa06 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py @@ -8,51 +8,55 @@ from typing import Any, Callable, List, Optional, Type, Union import torch -from torch import nn -from torch import Tensor +from torch import Tensor, nn from algoperf import spec from algoperf.init_utils import pytorch_default_init -def conv3x3(in_planes: int, - out_planes: int, - stride: int = 1, - groups: int = 1, - dilation: int = 1) -> nn.Conv2d: +def conv3x3( + in_planes: int, + out_planes: int, + stride: int = 1, + groups: int = 1, + dilation: int = 1, +) -> nn.Conv2d: """3x3 convolution with padding.""" return nn.Conv2d( - in_planes, - out_planes, - kernel_size=3, - stride=stride, - padding=dilation, - groups=groups, - bias=False, - dilation=dilation) + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: """1x1 convolution.""" return nn.Conv2d( - in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + in_planes, out_planes, kernel_size=1, stride=stride, bias=False + ) class BasicBlock(nn.Module): """ResNet block.""" + expansion: int = 1 def __init__( - self, - inplanes: int, - planes: int, - stride: int = 1, - downsample: Optional[nn.Module] = None, - groups: int = 1, - base_width: int = 64, - dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None, - act_fnc: nn.Module = nn.ReLU(inplace=True) + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + 
dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + act_fnc: nn.Module = nn.ReLU(inplace=True), ) -> None: super().__init__() if norm_layer is None: @@ -60,7 +64,7 @@ def __init__( if groups != 1 or base_width != 64: raise ValueError('BasicBlock only supports groups=1 and base_width=64') if dilation > 1: - raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + raise NotImplementedError('Dilation > 1 not supported in BasicBlock') # Both self.conv1 and self.downsample layers downsample # the input when stride != 1. self.conv1 = conv3x3(inplanes, planes, stride) @@ -92,24 +96,25 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: class Bottleneck(nn.Module): """Bottleneck ResNet block.""" + expansion: int = 4 def __init__( - self, - inplanes: int, - planes: int, - stride: int = 1, - downsample: Optional[nn.Module] = None, - groups: int = 1, - base_width: int = 64, - dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None, - act_fnc: nn.Module = nn.ReLU(inplace=True) + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + act_fnc: nn.Module = nn.ReLU(inplace=True), ) -> None: super().__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d - width = int(planes * (base_width / 64.)) * groups + width = int(planes * (base_width / 64.0)) * groups # Both self.conv2 and self.downsample layers downsample # the input when stride != 1. self.conv1 = conv1x1(inplanes, width) @@ -146,18 +151,19 @@ def forward(self, x: Tensor) -> Tensor: class ResNet(nn.Module): - - def __init__(self, - block: Type[Union[BasicBlock, Bottleneck]], - layers: List[int], - num_classes: int = 1000, - zero_init_residual: bool = True, - groups: int = 1, - width_per_group: int = 64, - replace_stride_with_dilation: Optional[List[bool]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None, - act_fnc: nn.Module = nn.ReLU(inplace=True), - bn_init_scale: float = 0.) 
-> None: + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = True, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + act_fnc: nn.Module = nn.ReLU(inplace=True), + bn_init_scale: float = 0.0, + ) -> None: super().__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -171,37 +177,42 @@ def __init__(self, replace_stride_with_dilation = [False, False, False] if len(replace_stride_with_dilation) != 3: raise ValueError( - 'replace_stride_with_dilation should be None ' - f'or a 3-element tuple, got {replace_stride_with_dilation}') + 'replace_stride_with_dilation should be None ' + f'or a 3-element tuple, got {replace_stride_with_dilation}' + ) self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( - 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) + 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False + ) self.bn1 = norm_layer(self.inplanes) self.act_fnc = act_fnc self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, self.act_fnc, 64, layers[0]) self.layer2 = self._make_layer( - block, - self.act_fnc, - 128, - layers[1], - stride=2, - dilate=replace_stride_with_dilation[0]) + block, + self.act_fnc, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0], + ) self.layer3 = self._make_layer( - block, - self.act_fnc, - 256, - layers[2], - stride=2, - dilate=replace_stride_with_dilation[1]) + block, + self.act_fnc, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1], + ) self.layer4 = self._make_layer( - block, - self.act_fnc, - 512, - layers[3], - stride=2, - dilate=replace_stride_with_dilation[2]) + block, + self.act_fnc, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2], + ) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) @@ -212,7 +223,7 @@ def __init__(self, nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) nn.init.normal_(self.fc.weight, std=1e-2) - nn.init.constant_(self.fc.bias, 0.) 
+ nn.init.constant_(self.fc.bias, 0.0) # Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, @@ -226,13 +237,15 @@ def __init__(self, elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, bn_init_scale) - def _make_layer(self, - block: Type[Union[BasicBlock, Bottleneck]], - act_fnc: nn.Module, - planes: int, - blocks: int, - stride: int = 1, - dilate: bool = False) -> nn.Sequential: + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + act_fnc: nn.Module, + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + ) -> nn.Sequential: norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation @@ -241,34 +254,41 @@ def _make_layer(self, stride = 1 if stride != 1 or self.inplanes != planes * block.expansion: downsample = torch.nn.Sequential( - collections.OrderedDict([ - ("conv", conv1x1(self.inplanes, planes * block.expansion, - stride)), - ("bn", norm_layer(planes * block.expansion)), - ])) + collections.OrderedDict( + [ + ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)), + ('bn', norm_layer(planes * block.expansion)), + ] + ) + ) layers = [] layers.append( - block(self.inplanes, - planes, - stride, - downsample, - self.groups, - self.base_width, - previous_dilation, - norm_layer, - act_fnc)) + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + act_fnc, + ) + ) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append( - block( - self.inplanes, - planes, - groups=self.groups, - base_width=self.base_width, - dilation=self.dilation, - norm_layer=norm_layer, - act_fnc=act_fnc)) + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + act_fnc=act_fnc, + ) + ) return nn.Sequential(*layers) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py index c7a98e77a..1c5c0d952 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py @@ -11,8 +11,8 @@ import PIL import torch from torch import Tensor -from torchvision.transforms import functional as F from torchvision.transforms import InterpolationMode +from torchvision.transforms import functional as F from algoperf import spec @@ -24,8 +24,8 @@ def cutout(img: spec.Tensor, pad_size: int) -> spec.Tensor: # Double the pad size to match Jax implementation. pad_size = pad_size * 2 - x0 = int(max(0, x0 - pad_size / 2.)) - y0 = int(max(0, y0 - pad_size / 2.)) + x0 = int(max(0, x0 - pad_size / 2.0)) + y0 = int(max(0, y0 - pad_size / 2.0)) x1 = int(min(image_width, x0 + pad_size)) y1 = int(min(image_height, y0 + pad_size)) xy = (x0, y0, x1, y1) @@ -36,7 +36,7 @@ def cutout(img: spec.Tensor, pad_size: int) -> spec.Tensor: def solarize(img: spec.Tensor, threshold: float) -> spec.Tensor: img = np.array(img) - new_img = np.where(img < threshold, img, 255. 
- img) + new_img = np.where(img < threshold, img, 255.0 - img) return PIL.Image.fromarray(new_img.astype(np.uint8)) @@ -49,54 +49,56 @@ def solarize_add(img: spec.Tensor, addition: int = 0) -> spec.Tensor: return PIL.Image.fromarray(new_img) -def _apply_op(img: spec.Tensor, - op_name: str, - magnitude: float, - interpolation: InterpolationMode, - fill: Optional[List[float]]) -> spec.Tensor: +def _apply_op( + img: spec.Tensor, + op_name: str, + magnitude: float, + interpolation: InterpolationMode, + fill: Optional[List[float]], +) -> spec.Tensor: if op_name == 'ShearX': # Magnitude should be arctan(magnitude). img = F.affine( - img, - angle=0.0, - translate=[0, 0], - scale=1.0, - shear=[math.degrees(math.atan(magnitude)), 0.0], - interpolation=interpolation, - fill=fill, - center=[0, 0], + img, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[math.degrees(math.atan(magnitude)), 0.0], + interpolation=interpolation, + fill=fill, + center=[0, 0], ) elif op_name == 'ShearY': # Magnitude should be arctan(magnitude). img = F.affine( - img, - angle=0.0, - translate=[0, 0], - scale=1.0, - shear=[0.0, math.degrees(math.atan(magnitude))], - interpolation=interpolation, - fill=fill, - center=[0, 0], + img, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[0.0, math.degrees(math.atan(magnitude))], + interpolation=interpolation, + fill=fill, + center=[0, 0], ) elif op_name == 'TranslateX': img = F.affine( - img, - angle=0.0, - translate=[int(magnitude), 0], - scale=1.0, - interpolation=interpolation, - shear=[0.0, 0.0], - fill=fill, + img, + angle=0.0, + translate=[int(magnitude), 0], + scale=1.0, + interpolation=interpolation, + shear=[0.0, 0.0], + fill=fill, ) elif op_name == 'TranslateY': img = F.affine( - img, - angle=0.0, - translate=[0, int(magnitude)], - scale=1.0, - interpolation=interpolation, - shear=[0.0, 0.0], - fill=fill, + img, + angle=0.0, + translate=[0, int(magnitude)], + scale=1.0, + interpolation=interpolation, + shear=[0.0, 0.0], + fill=fill, ) elif op_name == 'Rotate': img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill) @@ -131,33 +133,32 @@ def _apply_op(img: spec.Tensor, def ops_space() -> Dict[str, Tuple[spec.Tensor, bool]]: return { - # op_name: (magnitudes, signed) - 'ShearX': (torch.tensor(0.3), True), - 'ShearY': (torch.tensor(0.3), True), - 'TranslateX': (torch.tensor(100), True), - 'TranslateY': (torch.tensor(100), True), - 'Rotate': (torch.tensor(30), True), - 'Brightness': (torch.tensor(1.9), False), - 'Color': (torch.tensor(1.9), False), - 'Contrast': (torch.tensor(1.9), False), - 'Sharpness': (torch.tensor(1.9), False), - 'Posterize': (torch.tensor(4), False), - 'Solarize': (torch.tensor(256), False), - 'SolarizeAdd': (torch.tensor(110), False), - 'AutoContrast': (torch.tensor(0.0), False), - 'Equalize': (torch.tensor(0.0), False), - 'Invert': (torch.tensor(0.0), False), - 'Cutout': (torch.tensor(40.0), False), + # op_name: (magnitudes, signed) + 'ShearX': (torch.tensor(0.3), True), + 'ShearY': (torch.tensor(0.3), True), + 'TranslateX': (torch.tensor(100), True), + 'TranslateY': (torch.tensor(100), True), + 'Rotate': (torch.tensor(30), True), + 'Brightness': (torch.tensor(1.9), False), + 'Color': (torch.tensor(1.9), False), + 'Contrast': (torch.tensor(1.9), False), + 'Sharpness': (torch.tensor(1.9), False), + 'Posterize': (torch.tensor(4), False), + 'Solarize': (torch.tensor(256), False), + 'SolarizeAdd': (torch.tensor(110), False), + 'AutoContrast': (torch.tensor(0.0), False), + 'Equalize': (torch.tensor(0.0), False), + 'Invert': 
(torch.tensor(0.0), False), + 'Cutout': (torch.tensor(40.0), False), } class RandAugment(torch.nn.Module): - def __init__( - self, - num_ops: int = 2, - interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[List[float]] = None, + self, + num_ops: int = 2, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, ) -> None: super().__init__() self.num_ops = num_ops @@ -183,5 +184,6 @@ def forward(self, img: spec.Tensor) -> spec.Tensor: # With 50% prob turn the magnitude negative. magnitude *= -1.0 img = _apply_op( - img, op_name, magnitude, interpolation=self.interpolation, fill=fill) + img, op_name, magnitude, interpolation=self.interpolation, fill=fill + ) return img diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index ed29271f3..85a35dc45 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -16,22 +16,21 @@ from torchvision import transforms from torchvision.datasets.folder import ImageFolder -from algoperf import data_utils -from algoperf import param_utils -from algoperf import pytorch_utils -from algoperf import spec import algoperf.random_utils as prng +from algoperf import data_utils, param_utils, pytorch_utils, spec from algoperf.workloads.imagenet_resnet import imagenet_v2 from algoperf.workloads.imagenet_resnet.imagenet_pytorch import randaugment from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import resnet50 -from algoperf.workloads.imagenet_resnet.workload import \ - BaseImagenetResNetWorkload +from algoperf.workloads.imagenet_resnet.workload import ( + BaseImagenetResNetWorkload, +) USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() def imagenet_v2_to_torch( - batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: + batch: Dict[str, spec.Tensor], +) -> Dict[str, spec.Tensor]: # Slice off the part of the batch for this device and then transpose from # [N, H, W, C] to [N, C, H, W]. Only transfer the inputs to GPU. new_batch = {} @@ -48,7 +47,6 @@ def imagenet_v2_to_torch( class ImagenetResNetWorkload(BaseImagenetResNetWorkload): - def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # Is set in submission_runner.py for workloads with PyTorch evaluation @@ -59,7 +57,8 @@ def __init__(self, *args, **kwargs) -> None: def eval_num_workers(self) -> int: if self._eval_num_workers is None: raise ValueError( - 'eval_num_workers property must be set before workload is used.') + 'eval_num_workers property must be set before workload is used.' 
+ ) return self._eval_num_workers @eval_num_workers.setter @@ -67,60 +66,68 @@ def eval_num_workers(self, eval_num_workers: int): self._eval_num_workers = eval_num_workers def _build_dataset( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - use_mixup: bool = False, - use_randaug: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + use_mixup: bool = False, + use_randaug: bool = False, + ) -> Iterator[Dict[str, spec.Tensor]]: del cache del repeat_final_dataset if split == 'test': np_iter = imagenet_v2.get_imagenet_v2_iter( - data_dir, - global_batch_size, - mean_rgb=self.train_mean, - stddev_rgb=self.train_stddev, - image_size=self.center_crop_size, - resize_size=self.resize_size) + data_dir, + global_batch_size, + mean_rgb=self.train_mean, + stddev_rgb=self.train_stddev, + image_size=self.center_crop_size, + resize_size=self.resize_size, + ) return map(imagenet_v2_to_torch, itertools.cycle(np_iter)) is_train = split == 'train' normalize = transforms.Normalize( - mean=[i / 255. for i in self.train_mean], - std=[i / 255. for i in self.train_stddev]) + mean=[i / 255.0 for i in self.train_mean], + std=[i / 255.0 for i in self.train_stddev], + ) if is_train: transform_config = [ - transforms.RandomResizedCrop( - self.center_crop_size, - scale=self.scale_ratio_range, - ratio=self.aspect_ratio_range), - transforms.RandomHorizontalFlip(), + transforms.RandomResizedCrop( + self.center_crop_size, + scale=self.scale_ratio_range, + ratio=self.aspect_ratio_range, + ), + transforms.RandomHorizontalFlip(), ] if use_randaug: transform_config.append(randaugment.RandAugment()) transform_config.extend([transforms.ToTensor(), normalize]) transform_config = transforms.Compose(transform_config) else: - transform_config = transforms.Compose([ + transform_config = transforms.Compose( + [ transforms.Resize(self.resize_size), transforms.CenterCrop(self.center_crop_size), transforms.ToTensor(), normalize, - ]) + ] + ) folder = 'train' if 'train' in split else 'val' dataset = ImageFolder( - os.path.join(data_dir, folder), transform=transform_config) + os.path.join(data_dir, folder), transform=transform_config + ) if split == 'eval_train': indices = list(range(self.num_train_examples)) random.Random(int(data_rng[0])).shuffle(indices) - dataset = torch.utils.data.Subset(dataset, - indices[:self.num_eval_train_examples]) + dataset = torch.utils.data.Subset( + dataset, indices[: self.num_eval_train_examples] + ) sampler = None if USE_PYTORCH_DDP: @@ -131,37 +138,34 @@ def _build_dataset( if USE_PYTORCH_DDP: if is_train: sampler = torch.utils.data.distributed.DistributedSampler( - dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True) + dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True + ) else: sampler = data_utils.DistributedEvalSampler( - dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False) + dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False + ) dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=ds_iter_batch_size, - shuffle=not USE_PYTORCH_DDP and is_train, - sampler=sampler, - num_workers=4 if is_train else self.eval_num_workers, - pin_memory=True, - drop_last=is_train, - persistent_workers=is_train) + dataset, + batch_size=ds_iter_batch_size, + shuffle=not USE_PYTORCH_DDP and is_train, + 
sampler=sampler, + num_workers=4 if is_train else self.eval_num_workers, + pin_memory=True, + drop_last=is_train, + persistent_workers=is_train, + ) dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE) dataloader = data_utils.cycle( - dataloader, - custom_sampler=USE_PYTORCH_DDP, - use_mixup=use_mixup, - mixup_alpha=0.2) + dataloader, + custom_sampler=USE_PYTORCH_DDP, + use_mixup=use_mixup, + mixup_alpha=0.2, + ) return dataloader - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Dropout is unused.""" - del dropout_rate - del aux_dropout_rate + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) if self.use_silu and self.use_gelu: @@ -188,34 +192,40 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['fc.weight', 'fc.bias'] def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = 0.0, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng + del dropout_rate model = params if mode == spec.ForwardPassMode.EVAL: if update_batch_norm: raise ValueError( - 'Batch norm statistics cannot be updated during evaluation.') + 'Batch norm statistics cannot be updated during evaluation.' + ) model.eval() if mode == spec.ForwardPassMode.TRAIN: model.train() model.apply( - functools.partial( - pytorch_utils.update_batch_norm_fn, - update_batch_norm=update_batch_norm)) + functools.partial( + pytorch_utils.update_batch_norm_fn, + update_batch_norm=update_batch_norm, + ) + ) contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): @@ -226,11 +236,12 @@ def model_fn( # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -238,10 +249,11 @@ def loss_fn( (not synced across devices). """ per_example_losses = F.cross_entropy( - logits_batch, - label_batch, - reduction='none', - label_smoothing=label_smoothing) + logits_batch, + label_batch, + reduction='none', + label_smoothing=label_smoothing, + ) # `mask_batch` is assumed to be shape [batch]. 
if mask_batch is not None: per_example_losses *= mask_batch @@ -250,15 +262,14 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), + 'per_example': per_example_losses, } - def _compute_metrics(self, - logits: spec.Tensor, - labels: spec.Tensor, - weights: spec.Tensor) -> Dict[str, spec.Tensor]: + def _compute_metrics( + self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor + ) -> Dict[str, spec.Tensor]: """Return the mean accuracy and loss as a dict.""" if weights is None: weights = torch.ones(len(logits), device=DEVICE) @@ -268,15 +279,17 @@ def _compute_metrics(self, summed_loss = self.loss_fn(labels, logits, weights)['summed'] return {'accuracy': accuracy, 'loss': summed_loss} - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del global_step data_rng, model_rng = prng.split(rng, 2) @@ -284,31 +297,33 @@ def _eval_model_on_split(self, is_test = split == 'test' # These iterators repeat indefinitely. self._eval_iters[split] = self._build_input_queue( - data_rng, - split=split, - global_batch_size=global_batch_size, - data_dir=data_dir, - cache=is_test, - repeat_final_dataset=is_test) + data_rng, + split=split, + global_batch_size=global_batch_size, + data_dir=data_dir, + cache=is_test, + repeat_final_dataset=is_test, + ) total_metrics = { - 'accuracy': torch.tensor(0., device=DEVICE), - 'loss': torch.tensor(0., device=DEVICE), + 'accuracy': torch.tensor(0.0, device=DEVICE), + 'loss': torch.tensor(0.0, device=DEVICE), } num_batches = int(math.ceil(num_examples / global_batch_size)) for _ in range(num_batches): batch = next(self._eval_iters[split]) logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - model_rng, - update_batch_norm=False) + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + model_rng, + update_batch_norm=False, + ) weights = batch.get('weights') batch_metrics = self._compute_metrics(logits, batch['targets'], weights) total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() + k: v + batch_metrics[k] for k, v in total_metrics.items() } if USE_PYTORCH_DDP: for metric in total_metrics.values(): @@ -317,7 +332,6 @@ def _eval_model_on_split(self, class ImagenetResNetSiLUWorkload(ImagenetResNetWorkload): - @property def use_silu(self) -> bool: return True @@ -332,7 +346,6 @@ def test_target_value(self) -> float: class ImagenetResNetGELUWorkload(ImagenetResNetWorkload): - @property def use_gelu(self) -> bool: return True @@ -347,7 +360,6 @@ def test_target_value(self) -> float: class ImagenetResNetLargeBNScaleWorkload(ImagenetResNetWorkload): - @property def bn_init_scale(self) -> float: return 8.0 diff --git a/algoperf/workloads/imagenet_resnet/imagenet_v2.py b/algoperf/workloads/imagenet_resnet/imagenet_v2.py 
index 84d364586..7a8e38f02 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_v2.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_v2.py @@ -8,37 +8,38 @@ import tensorflow_datasets as tfds -from algoperf import data_utils -from algoperf import spec +from algoperf import data_utils, spec from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline -def get_imagenet_v2_iter(data_dir: str, - global_batch_size: int, - mean_rgb: Tuple[float, float, float], - stddev_rgb: Tuple[float, float, float], - image_size: int, - resize_size: int) -> Iterator[Dict[str, spec.Tensor]]: +def get_imagenet_v2_iter( + data_dir: str, + global_batch_size: int, + mean_rgb: Tuple[float, float, float], + stddev_rgb: Tuple[float, float, float], + image_size: int, + resize_size: int, +) -> Iterator[Dict[str, spec.Tensor]]: """Always caches and repeats indefinitely.""" ds = tfds.load( - 'imagenet_v2/matched-frequency:3.0.0', - split='test', - data_dir=data_dir, - decoders={ - 'image': tfds.decode.SkipDecoding(), - }) + 'imagenet_v2/matched-frequency:3.0.0', + split='test', + data_dir=data_dir, + decoders={ + 'image': tfds.decode.SkipDecoding(), + }, + ) def _decode_example(example: Dict[str, float]) -> Dict[str, float]: - image = input_pipeline.preprocess_for_eval(example['image'], - mean_rgb, - stddev_rgb, - image_size, - resize_size) + image = input_pipeline.preprocess_for_eval( + example['image'], mean_rgb, stddev_rgb, image_size, resize_size + ) return {'inputs': image, 'targets': example['label']} ds = ds.map(_decode_example, num_parallel_calls=16) ds = ds.batch(global_batch_size) shard_pad_fn = functools.partial( - data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size) + data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size + ) it = map(shard_pad_fn, iter(ds)) return it diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py index 83fe97108..ef696e328 100644 --- a/algoperf/workloads/imagenet_resnet/workload.py +++ b/algoperf/workloads/imagenet_resnet/workload.py @@ -7,7 +7,6 @@ class BaseImagenetResNetWorkload(spec.Workload): - _num_classes: int = 1000 @property @@ -15,8 +14,9 @@ def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" return 'accuracy' - def has_reached_validation_target(self, eval_result: Dict[str, - float]) -> bool: + def has_reached_validation_target( + self, eval_result: Dict[str, float] + ) -> bool: return eval_result['validation/accuracy'] > self.validation_target_value @property @@ -58,8 +58,9 @@ def num_eval_train_examples(self) -> int: # Round up from num_validation_examples (which is the default for # num_eval_train_examples) to the next multiple of eval_batch_size, so that # we don't have to extract the correctly sized subset of the training data. - rounded_up_multiple = math.ceil(self.num_validation_examples / - self.eval_batch_size) + rounded_up_multiple = math.ceil( + self.num_validation_examples / self.eval_batch_size + ) return rounded_up_multiple * self.eval_batch_size @property @@ -109,38 +110,37 @@ def eval_period_time_sec(self) -> int: return 510 # 8.5 minutes. 
def _build_dataset( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - use_mixup: bool = False, - use_randaug: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + use_mixup: bool = False, + use_randaug: bool = False, + ) -> Iterator[Dict[str, spec.Tensor]]: raise NotImplementedError def _build_input_queue( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del num_batches if split == 'test': if not cache: raise ValueError('cache must be True for split=test.') if not repeat_final_dataset: raise ValueError('repeat_final_dataset must be True for split=test.') - return self._build_dataset(data_rng, - split, - data_dir, - global_batch_size, - cache, - repeat_final_dataset) + return self._build_dataset( + data_rng, split, data_dir, global_batch_size, cache, repeat_final_dataset + ) @property def step_hint(self) -> int: diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 7ce3a0395..5e38acd8b 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -7,24 +7,29 @@ from typing import Optional, Sequence, Union -from flax import linen as nn import jax.numpy as jnp +from flax import linen as nn from algoperf import spec +from algoperf.jax_utils import Dropout +DROPOUT_RATE = 0.0 -def posemb_sincos_2d(h: int, - w: int, - width: int, - temperature: int = 10_000., - dtype: jnp.dtype = jnp.float32) -> spec.Tensor: + +def posemb_sincos_2d( + h: int, + w: int, + width: int, + temperature: int = 10_000.0, + dtype: jnp.dtype = jnp.float32, +) -> spec.Tensor: """Follows the MoCo v3 logic.""" - y, x = jnp.mgrid[:h, :w] #pylint: disable=unpacking-non-sequence + y, x = jnp.mgrid[:h, :w] # pylint: disable=unpacking-non-sequence if width % 4 != 0: raise ValueError('Width must be mult of 4 for sincos posemb.') omega = jnp.arange(width // 4) / (width // 4 - 1) - omega = 1. / (temperature**omega) + omega = 1.0 / (temperature**omega) y = jnp.einsum('m,d->md', y.flatten(), omega) x = jnp.einsum('m,d->md', x.flatten(), omega) pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1) @@ -33,16 +38,19 @@ def posemb_sincos_2d(h: int, class MlpBlock(nn.Module): """Transformer MLP / feed-forward block.""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim. 
   use_glu: bool = False
-  dropout_rate: float = 0.0
+  dropout_rate: float = DROPOUT_RATE

   @nn.compact
-  def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
+  def __call__(
+    self, x: spec.Tensor, train: bool = True, dropout_rate=DROPOUT_RATE
+  ) -> spec.Tensor:
     """Applies Transformer MlpBlock module."""
     inits = {
-        'kernel_init': nn.initializers.xavier_uniform(),
-        'bias_init': nn.initializers.normal(stddev=1e-6),
+      'kernel_init': nn.initializers.xavier_uniform(),
+      'bias_init': nn.initializers.normal(stddev=1e-6),
     }

     d = x.shape[2]
@@ -53,13 +61,14 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
       y = nn.Dense(self.mlp_dim, **inits)(x)
       x = x * y

-    x = nn.Dropout(rate=self.dropout_rate)(x, train)
+    x = Dropout(dropout_rate)(x, train, rate=dropout_rate)
     x = nn.Dense(d, **inits)(x)
     return x


 class Encoder1DBlock(nn.Module):
   """Single transformer encoder block (MHSA + MLP)."""
+
   mlp_dim: Optional[int] = None  # Defaults to 4x input dim.
   num_heads: int = 12
   use_glu: bool = False
@@ -67,45 +76,46 @@ class Encoder1DBlock(nn.Module):
   dropout_rate: float = 0.0

   @nn.compact
-  def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
+  def __call__(
+    self, x: spec.Tensor, train: bool = True, dropout_rate=DROPOUT_RATE
+  ) -> spec.Tensor:
     if not self.use_post_layer_norm:
       y = nn.LayerNorm(name='LayerNorm_0')(x)
       y = nn.MultiHeadDotProductAttention(
-          num_heads=self.num_heads,
-          kernel_init=nn.initializers.xavier_uniform(),
-          deterministic=train,
-          name='MultiHeadDotProductAttention_1')(
-              y)
-      y = nn.Dropout(rate=self.dropout_rate)(y, train)
+        num_heads=self.num_heads,
+        kernel_init=nn.initializers.xavier_uniform(),
+        deterministic=train,
+        name='MultiHeadDotProductAttention_1',
+      )(y)
+      y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
       x = x + y

       y = nn.LayerNorm(name='LayerNorm_2')(x)
       y = MlpBlock(
-          mlp_dim=self.mlp_dim,
-          use_glu=self.use_glu,
-          dropout_rate=self.dropout_rate,
-          name='MlpBlock_3')(y, train)
-      y = nn.Dropout(rate=self.dropout_rate)(y, train)
+        mlp_dim=self.mlp_dim, use_glu=self.use_glu, name='MlpBlock_3'
+      )(y, train, dropout_rate=dropout_rate)
+      y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
       x = x + y
     else:
       y = x
       y = nn.MultiHeadDotProductAttention(
-          num_heads=self.num_heads,
-          kernel_init=nn.initializers.xavier_uniform(),
-          deterministic=train,
-          name='MultiHeadDotProductAttention_1')(
-              y)
-      y = nn.Dropout(rate=self.dropout_rate)(y, train)
+        num_heads=self.num_heads,
+        kernel_init=nn.initializers.xavier_uniform(),
+        deterministic=train,
+        name='MultiHeadDotProductAttention_1',
+      )(y)
+      y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
      x = x + y
       x = nn.LayerNorm(name='LayerNorm_0')(x)

       y = x
       y = MlpBlock(
-          mlp_dim=self.mlp_dim,
-          use_glu=self.use_glu,
-          dropout_rate=self.dropout_rate,
-          name='MlpBlock_3')(y, train)
-      y = nn.Dropout(rate=self.dropout_rate)(y, train)
+        mlp_dim=self.mlp_dim,
+        use_glu=self.use_glu,
+        name='MlpBlock_3',
+        dropout_rate=dropout_rate,
+      )(y, train, dropout_rate=dropout_rate)
+      y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
       x = x + y
       x = nn.LayerNorm(name='LayerNorm_2')(x)

@@ -114,25 +124,26 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:

 class Encoder(nn.Module):
   """Transformer Model Encoder for sequence to sequence translation."""
+
   depth: int
   mlp_dim: Optional[int] = None  # Defaults to 4x input dim.
   num_heads: int = 12
-  dropout_rate: float = 0.0
   use_glu: bool = False
   use_post_layer_norm: bool = False

   @nn.compact
-  def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
+  def __call__(
+    self, x: spec.Tensor, train: bool = True, dropout_rate: float = DROPOUT_RATE
+  ) -> spec.Tensor:
     # Input Encoder
     for lyr in range(self.depth):
-      block = Encoder1DBlock(
-          name=f'encoderblock_{lyr}',
-          mlp_dim=self.mlp_dim,
-          num_heads=self.num_heads,
-          use_glu=self.use_glu,
-          use_post_layer_norm=self.use_post_layer_norm,
-          dropout_rate=self.dropout_rate)
-      x = block(x, train)
+      x = Encoder1DBlock(
+        name=f'encoderblock_{lyr}',
+        mlp_dim=self.mlp_dim,
+        num_heads=self.num_heads,
+        use_glu=self.use_glu,
+        use_post_layer_norm=self.use_post_layer_norm,
+      )(x, train=train, dropout_rate=dropout_rate)
     if not self.use_post_layer_norm:
       return nn.LayerNorm(name='encoder_layernorm')(x)
     else:
@@ -141,24 +152,27 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:

 class MAPHead(nn.Module):
   """Multihead Attention Pooling."""
+
   mlp_dim: Optional[int] = None  # Defaults to 4x input dim
   num_heads: int = 12
+  dropout_rate: float = 0.0

   @nn.compact
-  def __call__(self, x):
+  def __call__(self, x, dropout_rate=DROPOUT_RATE):
     n, _, d = x.shape
-    probe = self.param('probe',
-                       nn.initializers.xavier_uniform(), (1, 1, d),
-                       x.dtype)
+    probe = self.param(
+      'probe', nn.initializers.xavier_uniform(), (1, 1, d), x.dtype
+    )
     probe = jnp.tile(probe, [n, 1, 1])

     x = nn.MultiHeadDotProductAttention(
-        num_heads=self.num_heads,
-        use_bias=True,
-        kernel_init=nn.initializers.xavier_uniform())(probe, x)
+      num_heads=self.num_heads,
+      use_bias=True,
+      kernel_init=nn.initializers.xavier_uniform(),
+    )(probe, x)

     y = nn.LayerNorm()(x)
-    x = x + MlpBlock(mlp_dim=self.mlp_dim)(y)
+    x = x + MlpBlock(mlp_dim=self.mlp_dim, dropout_rate=dropout_rate)(y)
     return x[:, 0]


@@ -172,29 +186,30 @@ class ViT(nn.Module):
   mlp_dim: Optional[int] = None  # Defaults to 4x input dim.
   num_heads: int = 12
   rep_size: Union[int, bool] = True
-  dropout_rate: Optional[float] = 0.0  # If None, defaults to 0.0.
+  dropout_rate: float = DROPOUT_RATE
   reinit: Optional[Sequence[str]] = None
   head_zeroinit: bool = True
   use_glu: bool = False
   use_post_layer_norm: bool = False
   use_map: bool = False

-  def get_posemb(self,
-                 seqshape: tuple,
-                 width: int,
-                 dtype: jnp.dtype = jnp.float32) -> spec.Tensor:
+  def get_posemb(
+    self, seqshape: tuple, width: int, dtype: jnp.dtype = jnp.float32
+  ) -> spec.Tensor:
     return posemb_sincos_2d(*seqshape, width, dtype=dtype)

   @nn.compact
-  def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor:
+  def __call__(
+    self, x: spec.Tensor, *, train: bool = False, dropout_rate=DROPOUT_RATE
+  ) -> spec.Tensor:
     # Patch extraction
     x = nn.Conv(
-        self.width,
-        self.patch_size,
-        strides=self.patch_size,
-        padding='VALID',
-        name='conv_patch_extract')(
-            x)
+      self.width,
+      self.patch_size,
+      strides=self.patch_size,
+      padding='VALID',
+      name='conv_patch_extract',
+    )(x)

     n, h, w, c = x.shape
     x = jnp.reshape(x, [n, h * w, c])
@@ -202,23 +217,23 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor:
     # Add posemb before adding extra token.
x = x + self.get_posemb((h, w), c, x.dtype) - dropout_rate = self.dropout_rate - if dropout_rate is None: - dropout_rate = 0.0 - x = nn.Dropout(rate=dropout_rate)(x, not train) + x = Dropout(dropout_rate)(x, not train, rate=dropout_rate) x = Encoder( - depth=self.depth, - mlp_dim=self.mlp_dim, - num_heads=self.num_heads, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - dropout_rate=dropout_rate, - name='Transformer')( - x, train=not train) + depth=self.depth, + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + name='Transformer', + )(x, train=not train, dropout_rate=dropout_rate) if self.use_map: - x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x) + x = MAPHead( + num_heads=self.num_heads, + mlp_dim=self.mlp_dim, + dropout_rate=dropout_rate, + )(x, dropout_rate=dropout_rate) else: x = jnp.mean(x, axis=1) diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 35a6c46be..1637a2123 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -2,47 +2,44 @@ from typing import Dict, Optional, Tuple +import jax +import jax.numpy as jnp from flax import jax_utils from flax import linen as nn from flax.core import pop -import jax -import jax.numpy as jnp -from algoperf import param_utils -from algoperf import spec -from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ - ImagenetResNetWorkload +from algoperf import param_utils, spec +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ( + ImagenetResNetWorkload, +) from algoperf.workloads.imagenet_vit.imagenet_jax import models -from algoperf.workloads.imagenet_vit.workload import BaseImagenetVitWorkload -from algoperf.workloads.imagenet_vit.workload import decode_variant +from algoperf.workloads.imagenet_vit.workload import ( + BaseImagenetVitWorkload, + decode_variant, +) # Make sure we inherit from the ViT base workload first. 
class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): - - def initialized(self, key: spec.RandomState, - model: nn.Module) -> spec.ModelInitState: + def initialized( + self, key: spec.RandomState, model: nn.Module + ) -> spec.ModelInitState: input_shape = (1, 224, 224, 3) - params_rng, dropout_rng = jax.random.split(key) - variables = jax.jit( - model.init)({'params': params_rng, 'dropout': dropout_rng}, - jnp.ones(input_shape)) - model_state, params = pop(variables, "params") + params_rng, _ = jax.random.split(key) + variables = jax.jit(model.init)( + {'params': params_rng}, jnp.ones(input_shape) + ) + model_state, params = pop(variables, 'params') return params, model_state - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - del aux_dropout_rate + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: self._model = models.ViT( - dropout_rate=dropout_rate, - num_classes=self._num_classes, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - use_map=self.use_map, - **decode_variant('S/16')) + num_classes=self._num_classes, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + use_map=self.use_map, + **decode_variant('S/16'), + ) params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -54,44 +51,54 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'head' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None, + dropout_rate: float = models.DROPOUT_RATE, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm + del use_running_average_bn train = mode == spec.ForwardPassMode.TRAIN - logits = self._model.apply({'params': params}, - augmented_and_preprocessed_input_batch['inputs'], - rngs={'dropout': rng}, - train=train) + logits = self._model.apply( + {'params': params}, + augmented_and_preprocessed_input_batch['inputs'], + rngs={'dropout': rng}, + train=train, + dropout_rate=dropout_rate, + ) return logits, None - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: model_state = None - return super()._eval_model_on_split(split, - num_examples, - global_batch_size, - params, - model_state, - rng, - data_dir, - global_step) + return super()._eval_model_on_split( + split, + num_examples, + global_batch_size, + params, + model_state, + rng, 
+ data_dir, + global_step, + ) class ImagenetVitGluWorkload(ImagenetVitWorkload): - @property def use_glu(self) -> bool: return True @@ -106,7 +113,6 @@ def test_target_value(self) -> float: class ImagenetVitPostLNWorkload(ImagenetVitWorkload): - @property def use_post_layer_norm(self) -> bool: return True @@ -121,7 +127,6 @@ def test_target_value(self) -> float: class ImagenetVitMapWorkload(ImagenetVitWorkload): - @property def use_map(self) -> bool: return True diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index fcf0992d3..fc2a3cd46 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -9,25 +9,29 @@ from typing import Any, Optional, Tuple, Union import torch -from torch import nn import torch.nn.functional as F +from torch import nn -from algoperf import init_utils -from algoperf import spec +from algoperf import init_utils, spec from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention +DROPOUT_RATE = 0.0 + -def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: +def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.0) -> spec.Tensor: """Follows the MoCo v3 logic.""" _, width, h, w = patches.shape device = patches.device - y, x = torch.meshgrid(torch.arange(h, device=device), - torch.arange(w, device=device), indexing='ij') + y, x = torch.meshgrid( + torch.arange(h, device=device), + torch.arange(w, device=device), + indexing='ij', + ) if width % 4 != 0: raise ValueError('Width must be mult of 4 for sincos posemb.') omega = torch.arange(width // 4, device=device) / (width // 4 - 1) - omega = 1. / (temperature**omega) + omega = 1.0 / (temperature**omega) y = y.flatten()[:, None] * omega[None, :] x = x.flatten()[:, None] * omega[None, :] pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) @@ -38,21 +42,19 @@ class MlpBlock(nn.Module): """Transformer MLP / feed-forward block.""" def __init__( - self, - width: int, - mlp_dim: Optional[int] = None, # Defaults to 4x input dim. - use_glu: bool = False, - dropout_rate: float = 0.0) -> None: + self, + width: int, + mlp_dim: Optional[int] = None, # Defaults to 4x input dim. 
+ use_glu: bool = False, + ) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim or 4 * width self.use_glu = use_glu - self.dropout_rate = dropout_rate self.linear1 = nn.Linear(self.width, self.mlp_dim) self.act_fnc = nn.GELU(approximate='tanh') - self.dropout = nn.Dropout(self.dropout_rate) if self.use_glu: self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim) @@ -70,7 +72,7 @@ def reset_parameters(self) -> None: if module.bias is not None: module.bias.data.normal_(std=1e-6) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: x = self.linear1(x) x = self.act_fnc(x) @@ -78,7 +80,7 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: y = self.glu_linear(x) x = x * y - x = self.dropout(x) + x = F.dropout(x, dropout_rate, training=self.training) x = self.linear2(x) return x @@ -86,17 +88,15 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: class SelfAttention(nn.Module): """Self-attention special case of multi-head dot-product attention.""" - def __init__(self, - width: int, - num_heads: int = 8, - dropout_rate: float = 0.0) -> None: + def __init__(self, width: int, num_heads: int = 8) -> None: super().__init__() self.width = width self.num_heads = num_heads assert width % num_heads == 0, ( - 'Memory dimension must be divisible by number of heads.') + 'Memory dimension must be divisible by number of heads.' + ) self.head_dim = int(width / num_heads) self.all_head_dim = self.num_heads * self.head_dim @@ -104,7 +104,6 @@ def __init__(self, self.query = nn.Linear(self.width, self.all_head_dim) self.key = nn.Linear(self.width, self.all_head_dim) self.value = nn.Linear(self.width, self.all_head_dim) - self.dropout = nn.Dropout(dropout_rate) self.out = nn.Linear(self.width, self.width) self.reset_parameters() @@ -113,14 +112,14 @@ def reset_parameters(self) -> None: if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight.data) if module.bias is not None: - nn.init.constant_(module.bias.data, 0.) 
+ nn.init.constant_(module.bias.data, 0.0) def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor: new_x_shape = x.size()[:-1] + (self.num_heads, self.head_dim) x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: mixed_query_layer = self.query(x) key_layer = self.transpose_for_scores(self.key(x)) @@ -131,7 +130,7 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: attention_scores = attention_scores / math.sqrt(self.head_dim) attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = self.dropout(attention_probs) + attention_probs = F.dropout(attention_probs, dropout_rate, self.training) context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() @@ -144,13 +143,14 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: class Encoder1DBlock(nn.Module): """Single transformer encoder block (MHSA + MLP).""" - def __init__(self, - width: int, - mlp_dim: Optional[int] = None, - num_heads: int = 12, - use_glu: bool = False, - use_post_layer_norm: bool = False, - dropout_rate: float = 0.0) -> None: + def __init__( + self, + width: int, + mlp_dim: Optional[int] = None, + num_heads: int = 12, + use_glu: bool = False, + use_post_layer_norm: bool = False, + ) -> None: super().__init__() self.width = width @@ -161,35 +161,32 @@ def __init__(self, self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) self.self_attention1 = SelfAttention(self.width, self.num_heads) - self.dropout = nn.Dropout(dropout_rate) self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) self.mlp3 = MlpBlock( - width=self.width, - mlp_dim=self.mlp_dim, - use_glu=self.use_glu, - dropout_rate=dropout_rate) + width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu + ) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: if not self.use_post_layer_norm: y = self.layer_norm0(x) - y = self.self_attention1(y) - y = self.dropout(y) + y = self.self_attention1(y, dropout_rate) + y = F.dropout(y, dropout_rate, training=self.training) x = x + y y = self.layer_norm2(x) - y = self.mlp3(y) - y = self.dropout(y) + y = self.mlp3(y, dropout_rate) + y = F.dropout(y, dropout_rate, training=self.training) x = x + y else: y = x - y = self.self_attention1(y) - y = self.dropout(y) + y = self.self_attention1(y, dropout_rate) + y = F.dropout(y, dropout_rate, training=self.training) x = x + y x = self.layer_norm0(x) y = x - y = self.mlp3(y) - y = self.dropout(y) + y = self.mlp3(y, dropout_rate) + y = F.dropout(y, dropout_rate, training=self.training) x = x + y x = self.layer_norm2(x) return x @@ -198,14 +195,15 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: class Encoder(nn.Module): """Transformer Model Encoder for sequence to sequence translation.""" - def __init__(self, - depth: int, - width: int, - mlp_dim: Optional[int] = None, - num_heads: int = 12, - use_glu: bool = False, - use_post_layer_norm: bool = False, - dropout_rate: float = 0.0) -> None: + def __init__( + self, + depth: int, + width: int, + mlp_dim: Optional[int] = None, + num_heads: int = 12, + use_glu: bool = False, + use_post_layer_norm: bool = False, + ) -> None: super().__init__() self.depth = depth @@ -215,24 +213,28 @@ def __init__(self, self.use_glu = use_glu self.use_post_layer_norm = use_post_layer_norm - self.net = nn.ModuleList([ - Encoder1DBlock(self.width, - self.mlp_dim, - self.num_heads, - 
self.use_glu, - self.use_post_layer_norm, - dropout_rate) for _ in range(depth) - ]) + self.net = nn.ModuleList( + [ + Encoder1DBlock( + self.width, + self.mlp_dim, + self.num_heads, + self.use_glu, + self.use_post_layer_norm, + ) + for _ in range(depth) + ] + ) if not self.use_post_layer_norm: self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) else: self.encoder_norm = None - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: # Input Encoder. for block in self.net: - x = block(x) + x = block(x, dropout_rate) if not self.use_post_layer_norm: return self.encoder_norm(x) else: @@ -242,10 +244,9 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: class MAPHead(nn.Module): """Multihead Attention Pooling.""" - def __init__(self, - width: int, - mlp_dim: Optional[int] = None, - num_heads: int = 12): + def __init__( + self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12 + ): super().__init__() self.width = width self.mlp_dim = mlp_dim @@ -255,17 +256,18 @@ def __init__(self, nn.init.xavier_uniform_(self.probe.data) self.mha = MultiheadAttention( - self.width, num_heads=self.num_heads, self_attn=False, bias=True) + self.width, num_heads=self.num_heads, self_attn=False, bias=True + ) self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: n, _, _ = x.shape probe = torch.tile(self.probe, [n, 1, 1]) - x = self.mha(probe, x)[0] + x = self.mha(probe, x, dropout_rate=dropout_rate)[0] y = self.layer_norm(x) - x = x + self.mlp(y) + x = x + self.mlp(y, dropout_rate) return x[:, 0] @@ -277,23 +279,21 @@ class ViT(nn.Module): channels: int = 3 def __init__( - self, - num_classes: int = 1000, - patch_size: Tuple[int, int] = (16, 16), - width: int = 768, - depth: int = 12, - mlp_dim: Optional[int] = None, # Defaults to 4x input dim. - num_heads: int = 12, - rep_size: Union[int, bool] = True, - dropout_rate: Optional[float] = 0.0, - head_zeroinit: bool = True, - use_glu: bool = False, - use_post_layer_norm: bool = False, - use_map: bool = False, - dtype: Any = torch.float32) -> None: + self, + num_classes: int = 1000, + patch_size: Tuple[int, int] = (16, 16), + width: int = 768, + depth: int = 12, + mlp_dim: Optional[int] = None, # Defaults to 4x input dim. 
+ num_heads: int = 12, + rep_size: Union[int, bool] = True, + head_zeroinit: bool = True, + use_glu: bool = False, + use_post_layer_norm: bool = False, + use_map: bool = False, + dtype: Any = torch.float32, + ) -> None: super().__init__() - if dropout_rate is None: - dropout_rate = 0.0 self.num_classes = num_classes self.patch_size = patch_size @@ -313,21 +313,21 @@ def __init__( self.pre_logits = nn.Linear(self.width, rep_size) self.conv_patch_extract = nn.Conv2d( - self.channels, - self.width, - self.patch_size, - stride=self.patch_size, - padding='valid') - self.dropout = nn.Dropout(p=dropout_rate) + self.channels, + self.width, + self.patch_size, + stride=self.patch_size, + padding='valid', + ) self.encoder = Encoder( - depth=self.depth, - width=self.width, - mlp_dim=self.mlp_dim, - num_heads=self.num_heads, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - dropout_rate=dropout_rate) + depth=self.depth, + width=self.width, + mlp_dim=self.mlp_dim, + num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + ) if self.num_classes: self.head = nn.Linear(self.width, self.num_classes) @@ -347,15 +347,17 @@ def reset_parameters(self) -> None: if self.num_classes: if self.head_zeroinit: - nn.init.constant_(self.head.weight.data, 0.) - nn.init.constant_(self.head.bias.data, 0.) + nn.init.constant_(self.head.weight.data, 0.0) + nn.init.constant_(self.head.bias.data, 0.0) else: init_utils.pytorch_default_init(self.head) def get_posemb(self, x: spec.Tensor) -> spec.Tensor: return posemb_sincos_2d(x).type(self.dtype) - def forward(self, x: spec.Tensor) -> spec.Tensor: + def forward( + self, x: spec.Tensor, dropout_rate: float = DROPOUT_RATE + ) -> spec.Tensor: # Patch extraction. x = self.conv_patch_extract(x) @@ -367,11 +369,11 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: x = torch.transpose(torch.reshape(x, (n, c, h * w)), 1, 2) x = x + pes - x = self.dropout(x) - x = self.encoder(x) + x = F.dropout(x, dropout_rate, training=self.training) + x = self.encoder(x, dropout_rate) if self.use_map: - x = self.map(x) + x = self.map(x, dropout_rate) else: x = torch.mean(x, dim=1) diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index 97bb38515..9c6faf70b 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -1,40 +1,35 @@ """ImageNet ViT workload implemented in PyTorch.""" import contextlib -from typing import Dict, Optional, Tuple +from typing import Dict, Tuple import torch from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils -from algoperf import pytorch_utils -from algoperf import spec -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ - ImagenetResNetWorkload +from algoperf import param_utils, pytorch_utils, spec +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ( + ImagenetResNetWorkload, +) from algoperf.workloads.imagenet_vit.imagenet_pytorch import models -from algoperf.workloads.imagenet_vit.workload import BaseImagenetVitWorkload -from algoperf.workloads.imagenet_vit.workload import decode_variant +from algoperf.workloads.imagenet_vit.workload import ( + BaseImagenetVitWorkload, + decode_variant, +) USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() # Make sure we inherit from the ViT base workload first. 
class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): - - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - del aux_dropout_rate + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = models.ViT( - dropout_rate=dropout_rate, - num_classes=self._num_classes, - use_glu=self.use_glu, - use_post_layer_norm=self.use_post_layer_norm, - use_map=self.use_map, - **decode_variant('S/16')) + num_classes=self._num_classes, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + use_map=self.use_map, + **decode_variant('S/16'), + ) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -49,13 +44,15 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['head.weight', 'head.bias'] def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -69,18 +66,20 @@ def model_fn( model.train() contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): - logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) + logits_batch = model( + augmented_and_preprocessed_input_batch['inputs'], + dropout_rate=dropout_rate, + ) return logits_batch, None class ImagenetVitGluWorkload(ImagenetVitWorkload): - @property def use_glu(self) -> bool: return True @@ -95,7 +94,6 @@ def test_target_value(self) -> float: class ImagenetVitPostLNWorkload(ImagenetVitWorkload): - @property def use_post_layer_norm(self) -> bool: return True @@ -110,7 +108,6 @@ def test_target_value(self) -> float: class ImagenetVitMapWorkload(ImagenetVitWorkload): - @property def use_map(self) -> bool: return True diff --git a/algoperf/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py index f249ddee8..2a0070ba4 100644 --- a/algoperf/workloads/imagenet_vit/workload.py +++ b/algoperf/workloads/imagenet_vit/workload.py @@ -3,8 +3,9 @@ from typing import Dict, Iterator, Optional from algoperf import spec -from algoperf.workloads.imagenet_resnet.workload import \ - BaseImagenetResNetWorkload +from algoperf.workloads.imagenet_resnet.workload import ( + BaseImagenetResNetWorkload, +) def decode_variant(variant: str) -> Dict[str, int]: @@ -12,46 +13,52 @@ def decode_variant(variant: str) -> Dict[str, int]: v, patch = variant.split('/') return { - # Reference: Table 2 of https://arxiv.org/abs/2106.04560. 
- 'width': { - 'Ti': 192, - 'S': 384, - 'M': 512, - 'B': 768, - 'L': 1024, - 'H': 1280, - 'g': 1408, - 'G': 1664, - }[v], - 'depth': { - 'Ti': 12, - 'S': 12, - 'M': 12, - 'B': 12, - 'L': 24, - 'H': 32, - 'g': 40, - 'G': 48, - }[v], - 'mlp_dim': { - 'Ti': 768, - 'S': 1536, - 'M': 2048, - 'B': 3072, - 'L': 4096, - 'H': 5120, - 'g': 6144, - 'G': 8192, - }[v], - 'num_heads': { - 'Ti': 3, 'S': 6, 'M': 8, 'B': 12, 'L': 16, 'H': 16, 'g': 16, 'G': 16 - }[v], - 'patch_size': (int(patch), int(patch)), + # Reference: Table 2 of https://arxiv.org/abs/2106.04560. + 'width': { + 'Ti': 192, + 'S': 384, + 'M': 512, + 'B': 768, + 'L': 1024, + 'H': 1280, + 'g': 1408, + 'G': 1664, + }[v], + 'depth': { + 'Ti': 12, + 'S': 12, + 'M': 12, + 'B': 12, + 'L': 24, + 'H': 32, + 'g': 40, + 'G': 48, + }[v], + 'mlp_dim': { + 'Ti': 768, + 'S': 1536, + 'M': 2048, + 'B': 3072, + 'L': 4096, + 'H': 5120, + 'g': 6144, + 'G': 8192, + }[v], + 'num_heads': { + 'Ti': 3, + 'S': 6, + 'M': 8, + 'B': 12, + 'L': 16, + 'H': 16, + 'g': 16, + 'G': 16, + }[v], + 'patch_size': (int(patch), int(patch)), } class BaseImagenetVitWorkload(BaseImagenetResNetWorkload): - @property def validation_target_value(self) -> float: return 1 - 0.22691 # 0.77309 @@ -88,25 +95,28 @@ def eval_period_time_sec(self) -> int: return 7 * 60 # 7 mins. def _build_dataset( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - use_mixup: bool = False, - use_randaug: bool = False) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + use_mixup: bool = False, + use_randaug: bool = False, + ) -> Iterator[Dict[str, spec.Tensor]]: # We use mixup and Randaugment for ViT workloads. 
use_mixup = use_randaug = split == 'train' - return super()._build_dataset(data_rng, - split, - data_dir, - global_batch_size, - cache, - repeat_final_dataset, - use_mixup, - use_randaug) + return super()._build_dataset( + data_rng, + split, + data_dir, + global_batch_size, + cache, + repeat_final_dataset, + use_mixup, + use_randaug, + ) @property def step_hint(self) -> int: diff --git a/algoperf/workloads/librispeech_conformer/input_pipeline.py b/algoperf/workloads/librispeech_conformer/input_pipeline.py index 1310e7b59..23ce8e3b7 100644 --- a/algoperf/workloads/librispeech_conformer/input_pipeline.py +++ b/algoperf/workloads/librispeech_conformer/input_pipeline.py @@ -4,13 +4,12 @@ import csv -from absl import logging import numpy as np import torch +from absl import logging class LibriSpeechDataset(torch.utils.data.Dataset): - def __init__(self, split, data_dir): super().__init__() self.data_dir = data_dir @@ -38,13 +37,14 @@ def __getitem__(self, index): audio_paddings = np.zeros_like(audio, dtype=np.float32) audio_paddings = np.pad( - audio_paddings, (0, 320000 - audio.shape[0]), constant_values=1.0) + audio_paddings, (0, 320000 - audio.shape[0]), constant_values=1.0 + ) audio = np.pad(audio, (0, 320000 - audio.shape[0]), constant_values=0.0) target_paddings = np.zeros_like(targets, dtype=np.float32) target_paddings = np.pad( - target_paddings, (0, 256 - target_paddings.shape[0]), - constant_values=1.0) + target_paddings, (0, 256 - target_paddings.shape[0]), constant_values=1.0 + ) targets = np.pad(targets, (0, 256 - targets.shape[0]), constant_values=0) audio = audio.astype(np.float32) audio_paddings = audio_paddings.astype(np.float32) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py index 9f45434d9..531e68a45 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py @@ -10,185 +10,186 @@ from typing import Any, Optional, Union -from flax import linen as nn -from flax import struct import jax import jax.numpy as jnp import numpy as np +from flax import linen as nn +from flax import struct # mel spectrum constants. 
_MEL_BREAK_FREQUENCY_HERTZ = 700.0 _MEL_HIGH_FREQUENCY_Q = 1127.0 LIBRISPEECH_MEAN_VECTOR = [ - -7.6047816276550293, - -7.1206226348876953, - -6.8864245414733887, - -6.8705768585205078, - -6.9667720794677734, - -7.1084094047546387, - -6.9528026580810547, - -6.783994197845459, - -6.6195521354675293, - -6.4876265525817871, - -6.4120659828186035, - -6.394047737121582, - -6.4244871139526367, - -6.3993711471557617, - -6.5158271789550781, - -6.7137999534606934, - -6.8476877212524414, - -6.9885001182556152, - -6.9221386909484863, - -7.146148681640625, - -7.2040400505065918, - -7.0537552833557129, - -7.3140382766723633, - -7.1223249435424805, - -7.30251407623291, - -7.1212143898010254, - -7.2425732612609863, - -7.1730537414550781, - -7.0979413986206055, - -7.088747501373291, - -6.9849910736083984, - -6.8787732124328613, - -6.7602753639221191, - -6.6300945281982422, - -6.5145769119262695, - -6.4245057106018066, - -6.356513500213623, - -6.31787633895874, - -6.2660770416259766, - -6.2468328475952148, - -6.2821526527404785, - -6.1908388137817383, - -6.2484354972839355, - -6.1472640037536621, - -6.0924725532531738, - -6.0171003341674805, - -5.9250402450561523, - -5.8535833358764648, - -5.8209109306335449, - -5.8118929862976074, - -5.80783748626709, - -5.7714629173278809, - -5.7453732490539551, - -5.7705655097961426, - -5.7765641212463379, - -5.7831673622131348, - -5.7954087257385254, - -5.7994823455810547, - -5.8023476600646973, - -5.8047118186950684, - -5.8168182373046875, - -5.8844799995422363, - -5.9727106094360352, - -6.0444660186767578, - -6.1284866333007812, - -6.2257585525512695, - -6.3157496452331543, - -6.39061164855957, - -6.4928598403930664, - -6.5498456954956055, - -6.6054320335388184, - -6.6508378982543945, - -6.66917610168457, - -6.6726889610290527, - -6.684234619140625, - -6.6974577903747559, - -6.75471830368042, - -6.7949142456054688, - -6.8634209632873535, - -6.94186544418335 + -7.6047816276550293, + -7.1206226348876953, + -6.8864245414733887, + -6.8705768585205078, + -6.9667720794677734, + -7.1084094047546387, + -6.9528026580810547, + -6.783994197845459, + -6.6195521354675293, + -6.4876265525817871, + -6.4120659828186035, + -6.394047737121582, + -6.4244871139526367, + -6.3993711471557617, + -6.5158271789550781, + -6.7137999534606934, + -6.8476877212524414, + -6.9885001182556152, + -6.9221386909484863, + -7.146148681640625, + -7.2040400505065918, + -7.0537552833557129, + -7.3140382766723633, + -7.1223249435424805, + -7.30251407623291, + -7.1212143898010254, + -7.2425732612609863, + -7.1730537414550781, + -7.0979413986206055, + -7.088747501373291, + -6.9849910736083984, + -6.8787732124328613, + -6.7602753639221191, + -6.6300945281982422, + -6.5145769119262695, + -6.4245057106018066, + -6.356513500213623, + -6.31787633895874, + -6.2660770416259766, + -6.2468328475952148, + -6.2821526527404785, + -6.1908388137817383, + -6.2484354972839355, + -6.1472640037536621, + -6.0924725532531738, + -6.0171003341674805, + -5.9250402450561523, + -5.8535833358764648, + -5.8209109306335449, + -5.8118929862976074, + -5.80783748626709, + -5.7714629173278809, + -5.7453732490539551, + -5.7705655097961426, + -5.7765641212463379, + -5.7831673622131348, + -5.7954087257385254, + -5.7994823455810547, + -5.8023476600646973, + -5.8047118186950684, + -5.8168182373046875, + -5.8844799995422363, + -5.9727106094360352, + -6.0444660186767578, + -6.1284866333007812, + -6.2257585525512695, + -6.3157496452331543, + -6.39061164855957, + -6.4928598403930664, + -6.5498456954956055, + -6.6054320335388184, + 
-6.6508378982543945, + -6.66917610168457, + -6.6726889610290527, + -6.684234619140625, + -6.6974577903747559, + -6.75471830368042, + -6.7949142456054688, + -6.8634209632873535, + -6.94186544418335, ] LIBRISPEECH_STD_VECTOR = [ - 3.4353282451629639, - 3.5962932109832764, - 3.7012472152709961, - 3.7369205951690674, - 3.7535104751586914, - 3.693629264831543, - 3.6922497749328613, - 3.7641522884368896, - 3.8419716358184814, - 3.8999848365783691, - 3.9294240474700928, - 3.9317409992218018, - 3.9139585494995117, - 3.9031598567962646, - 3.8691999912261963, - 3.8155081272125244, - 3.7644970417022705, - 3.7099106311798096, - 3.6965086460113525, - 3.6003766059875488, - 3.5493226051330566, - 3.5465121269226074, - 3.45003604888916, - 3.4712812900543213, - 3.4084610939025879, - 3.4408135414123535, - 3.4104881286621094, - 3.4217638969421387, - 3.4312851428985596, - 3.4199209213256836, - 3.4305806159973145, - 3.4382665157318115, - 3.4580366611480713, - 3.4817991256713867, - 3.4958710670471191, - 3.5036792755126953, - 3.5047574043273926, - 3.4988734722137451, - 3.493056058883667, - 3.4822943210601807, - 3.459430456161499, - 3.4612770080566406, - 3.4559063911437988, - 3.4755423069000244, - 3.4971549510955811, - 3.5326557159423828, - 3.5705199241638184, - 3.5920312404632568, - 3.596907377243042, - 3.5913500785827637, - 3.5865931510925293, - 3.5826809406280518, - 3.5837743282318115, - 3.5895791053771973, - 3.5819313526153564, - 3.5837869644165039, - 3.5861184597015381, - 3.5889589786529541, - 3.592214822769165, - 3.5939455032348633, - 3.5856630802154541, - 3.5884113311767578, - 3.5921022891998291, - 3.5870490074157715, - 3.5806570053100586, - 3.5731067657470703, - 3.5617532730102539, - 3.54980731010437, - 3.5527374744415283, - 3.5475366115570068, - 3.5387849807739258, - 3.5256178379058838, - 3.5031836032867432, - 3.4922726154327393, - 3.4879646301269531, - 3.4725594520568848, - 3.4558389186859131, - 3.4351828098297119, - 3.4284293651580811, - 3.4299170970916748 + 3.4353282451629639, + 3.5962932109832764, + 3.7012472152709961, + 3.7369205951690674, + 3.7535104751586914, + 3.693629264831543, + 3.6922497749328613, + 3.7641522884368896, + 3.8419716358184814, + 3.8999848365783691, + 3.9294240474700928, + 3.9317409992218018, + 3.9139585494995117, + 3.9031598567962646, + 3.8691999912261963, + 3.8155081272125244, + 3.7644970417022705, + 3.7099106311798096, + 3.6965086460113525, + 3.6003766059875488, + 3.5493226051330566, + 3.5465121269226074, + 3.45003604888916, + 3.4712812900543213, + 3.4084610939025879, + 3.4408135414123535, + 3.4104881286621094, + 3.4217638969421387, + 3.4312851428985596, + 3.4199209213256836, + 3.4305806159973145, + 3.4382665157318115, + 3.4580366611480713, + 3.4817991256713867, + 3.4958710670471191, + 3.5036792755126953, + 3.5047574043273926, + 3.4988734722137451, + 3.493056058883667, + 3.4822943210601807, + 3.459430456161499, + 3.4612770080566406, + 3.4559063911437988, + 3.4755423069000244, + 3.4971549510955811, + 3.5326557159423828, + 3.5705199241638184, + 3.5920312404632568, + 3.596907377243042, + 3.5913500785827637, + 3.5865931510925293, + 3.5826809406280518, + 3.5837743282318115, + 3.5895791053771973, + 3.5819313526153564, + 3.5837869644165039, + 3.5861184597015381, + 3.5889589786529541, + 3.592214822769165, + 3.5939455032348633, + 3.5856630802154541, + 3.5884113311767578, + 3.5921022891998291, + 3.5870490074157715, + 3.5806570053100586, + 3.5731067657470703, + 3.5617532730102539, + 3.54980731010437, + 3.5527374744415283, + 3.5475366115570068, + 3.5387849807739258, + 3.5256178379058838, + 
3.5031836032867432, + 3.4922726154327393, + 3.4879646301269531, + 3.4725594520568848, + 3.4558389186859131, + 3.4351828098297119, + 3.4284293651580811, + 3.4299170970916748, ] @struct.dataclass class LibrispeechPreprocessingConfig: """Config to hold all preprocessing options for librispeech dataset.""" + sample_rate: float = 16000.0 frame_size_ms: float = 25.0 frame_step_ms: float = 10.0 @@ -208,8 +209,9 @@ class LibrispeechPreprocessingConfig: def _hertz_to_mel(frequencies_hertz): """Convert hertz to mel.""" - return _MEL_HIGH_FREQUENCY_Q * jnp.log(1.0 + (frequencies_hertz / - _MEL_BREAK_FREQUENCY_HERTZ)) + return _MEL_HIGH_FREQUENCY_Q * jnp.log( + 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ) + ) def _pad_end_length(num_timesteps, frame_step, frame_size): @@ -221,11 +223,13 @@ def _pad_end_length(num_timesteps, frame_step, frame_size): return padded_length - num_timesteps -def frame(x, - frame_length: int, - frame_step: int, - pad_end: bool = False, - pad_value: Union[int, float] = 0.0): +def frame( + x, + frame_length: int, + frame_step: int, + pad_end: bool = False, + pad_value: Union[int, float] = 0.0, +): """Slides a window and extract values. This function extracts `x[:, n:n+frame_length, :]` with sliding `n` with @@ -251,24 +255,31 @@ def frame(x, if pad_end: num_extends = _pad_end_length(num_timesteps, frame_step, frame_length) x = jnp.pad( - x, ((0, 0), (0, num_extends), (0, 0)), - 'constant', - constant_values=pad_value) + x, + ((0, 0), (0, num_extends), (0, 0)), + 'constant', + constant_values=pad_value, + ) flat_y = jax.lax.conv_general_dilated_patches( - x, (frame_length,), (frame_step,), - 'VALID', - dimension_numbers=('NTC', 'OIT', 'NTC')) + x, + (frame_length,), + (frame_step,), + 'VALID', + dimension_numbers=('NTC', 'OIT', 'NTC'), + ) ret = flat_y.reshape(flat_y.shape[:-1] + (num_channels, frame_length)) return ret.transpose((0, 1, 3, 2)) -def linear_to_mel_weight_matrix(num_mel_bins: int = 20, - num_spectrogram_bins: int = 129, - sample_rate: Union[int, float] = 8000, - lower_edge_hertz: Union[int, float] = 125.0, - upper_edge_hertz: Union[int, float] = 3800.0, - dtype: Any = jnp.float32): +def linear_to_mel_weight_matrix( + num_mel_bins: int = 20, + num_spectrogram_bins: int = 129, + sample_rate: Union[int, float] = 8000, + lower_edge_hertz: Union[int, float] = 125.0, + upper_edge_hertz: Union[int, float] = 3800.0, + dtype: Any = jnp.float32, +): r"""Jax-port of `tf.signal.linear_to_mel_weight_matrix`. Args: @@ -300,23 +311,29 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20, if num_mel_bins <= 0: raise ValueError('num_mel_bins must be positive. Got: %s' % num_mel_bins) if lower_edge_hertz < 0.0: - raise ValueError('lower_edge_hertz must be non-negative. Got: %s' % - lower_edge_hertz) + raise ValueError( + 'lower_edge_hertz must be non-negative. Got: %s' % lower_edge_hertz + ) if lower_edge_hertz >= upper_edge_hertz: - raise ValueError('lower_edge_hertz %.1f >= upper_edge_hertz %.1f' % - (lower_edge_hertz, upper_edge_hertz)) + raise ValueError( + 'lower_edge_hertz %.1f >= upper_edge_hertz %.1f' + % (lower_edge_hertz, upper_edge_hertz) + ) if sample_rate <= 0.0: raise ValueError('sample_rate must be positive. Got: %s' % sample_rate) if upper_edge_hertz > sample_rate / 2: - raise ValueError('upper_edge_hertz must not be larger than the Nyquist ' - 'frequency (sample_rate / 2). Got %s for sample_rate: %s' % - (upper_edge_hertz, sample_rate)) + raise ValueError( + 'upper_edge_hertz must not be larger than the Nyquist ' + 'frequency (sample_rate / 2). 
Got %s for sample_rate: %s' + % (upper_edge_hertz, sample_rate) + ) # HTK excludes the spectrogram DC bin. bands_to_zero = 1 nyquist_hertz = sample_rate / 2.0 linear_frequencies = jnp.linspace( - 0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype)[bands_to_zero:] + 0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype + )[bands_to_zero:] spectrogram_bins_mel = _hertz_to_mel(linear_frequencies)[:, jnp.newaxis] # Compute num_mel_bins triples of (lower_edge, center, upper_edge). The @@ -324,10 +341,11 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20, # Accordingly, we divide [lower_edge_hertz, upper_edge_hertz] into # num_mel_bins + 2 pieces. edges = jnp.linspace( - _hertz_to_mel(lower_edge_hertz), - _hertz_to_mel(upper_edge_hertz), - num_mel_bins + 2, - dtype=dtype) + _hertz_to_mel(lower_edge_hertz), + _hertz_to_mel(upper_edge_hertz), + num_mel_bins + 2, + dtype=dtype, + ) # Split the triples up and reshape them into [1, num_mel_bins] tensors. lower_edge_mel = edges[:-2][jnp.newaxis, :] @@ -337,9 +355,11 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20, # Calculate lower and upper slopes for every spectrogram bin. # Line segments are linear in the mel domain, not Hertz. lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / ( - center_mel - lower_edge_mel) + center_mel - lower_edge_mel + ) upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / ( - upper_edge_mel - center_mel) + upper_edge_mel - center_mel + ) # Intersect the line segments with each other and zero. mel_weights_matrix = jnp.maximum(0.0, jnp.minimum(lower_slopes, upper_slopes)) @@ -366,23 +386,26 @@ def _hanning_greco(win_support, frame_size, dtype): """ if frame_size < win_support: raise ValueError( - 'Provided frame_size = {} is lower than win_support = {}'.format( - frame_size, win_support)) + 'Provided frame_size = {} is lower than win_support = {}'.format( + frame_size, win_support + ) + ) arg = jnp.pi * 2.0 / (win_support) - hann = 0.5 - (0.5 * jnp.cos(arg * - (jnp.arange(win_support, dtype=dtype) + 0.5))) + hann = 0.5 - ( + 0.5 * jnp.cos(arg * (jnp.arange(win_support, dtype=dtype) + 0.5)) + ) zero_size = frame_size - win_support return jnp.pad(hann, [(0, zero_size)]) def _next_pow_of_two(x: Union[int, float]) -> int: - return int(2**np.ceil(np.log2(x))) + return int(2 ** np.ceil(np.log2(x))) class SpectrogramFrontend(nn.Module): - """Layer to convert input audio signals from time domain to frequency domain. - """ + """Layer to convert input audio signals from time domain to frequency domain.""" + config: LibrispeechPreprocessingConfig = None input_scale_factor: float = 1.0 output_log: bool = False @@ -390,8 +413,9 @@ class SpectrogramFrontend(nn.Module): def setup(self) -> None: p = self.config self._frame_step = int(round(p.sample_rate * p.frame_step_ms / 1000.0)) - self._frame_size = int(round( - p.sample_rate * p.frame_size_ms / 1000.0)) + 1 # +1 for the preemph + self._frame_size = ( + int(round(p.sample_rate * p.frame_size_ms / 1000.0)) + 1 + ) # +1 for the preemph # TF-version has maximum of 512, but it's not always necessary self.fft_size = _next_pow_of_two(self._frame_size) @@ -421,32 +445,39 @@ def f(frame_size, dtype): def _apply_preemphasis(self, framed_signal): p = self.config if p.preemph_htk_flavor: - return jnp.concatenate([ - framed_signal[:, :, :1, :] * (1. 
- p.preemph), - (framed_signal[:, :, 1:-1, :] - - p.preemph * framed_signal[:, :, :-2, :]) - ], - axis=2) + return jnp.concatenate( + [ + framed_signal[:, :, :1, :] * (1.0 - p.preemph), + ( + framed_signal[:, :, 1:-1, :] + - p.preemph * framed_signal[:, :, :-2, :] + ), + ], + axis=2, + ) else: - return (framed_signal[:, :, 1:, :] - - p.preemph * framed_signal[:, :, :-1, :]) + return ( + framed_signal[:, :, 1:, :] - p.preemph * framed_signal[:, :, :-1, :] + ) def fprop_paddings(self, input_paddings): p = self.config if p.pad_end: - num_extends = _pad_end_length(input_paddings.shape[1], - self._frame_step, - self._frame_size) + num_extends = _pad_end_length( + input_paddings.shape[1], self._frame_step, self._frame_size + ) input_paddings = jnp.pad( - input_paddings, ((0, 0), (0, num_extends)), constant_values=1.0) + input_paddings, ((0, 0), (0, num_extends)), constant_values=1.0 + ) return jax.lax.reduce_window( - input_paddings, - init_value=1.0, - computation=jax.lax.min, - window_dimensions=[1, self._frame_size], - window_strides=[1, self._frame_step], - padding='valid') + input_paddings, + init_value=1.0, + computation=jax.lax.min, + window_dimensions=[1, self._frame_size], + window_strides=[1, self._frame_step], + padding='valid', + ) def next_prng_key(self, name='dropout'): return self.make_rng(name) @@ -469,7 +500,8 @@ def __call__(self, inputs, input_paddings): pcm_audio_chunk = inputs.astype(jnp.float32) * self.input_scale_factor framed_signal = frame( - pcm_audio_chunk, self._frame_size, self._frame_step, pad_end=p.pad_end) + pcm_audio_chunk, self._frame_size, self._frame_step, pad_end=p.pad_end + ) if p.preemph != 0.0: preemphasized = self._apply_preemphasis(framed_signal) @@ -477,8 +509,10 @@ def __call__(self, inputs, input_paddings): preemphasized = framed_signal[..., :-1, :] if p.noise_scale > 0.0: - noise_signal = jax.random.normal(self.next_prng_key(), - preemphasized.shape) * p.noise_scale + noise_signal = ( + jax.random.normal(self.next_prng_key(), preemphasized.shape) + * p.noise_scale + ) else: noise_signal = jnp.zeros(preemphasized.shape) @@ -501,8 +535,8 @@ def __call__(self, inputs, input_paddings): class MelFilterbankFrontend(nn.Module): - """Layer to compute log mel spectograms from input audio signals. 
- """ + """Layer to compute log mel spectograms from input audio signals.""" + config: LibrispeechPreprocessingConfig = None use_divide_stream: bool = True per_bin_mean: Optional[float] = None @@ -513,7 +547,8 @@ def setup(self): input_scale_factor = 2**-15 if self.use_divide_stream else 1.0 self.stft = SpectrogramFrontend( - p, input_scale_factor=input_scale_factor, output_log=False) + p, input_scale_factor=input_scale_factor, output_log=False + ) if self.per_bin_mean is None: per_bin_mean = [0.0] * p.num_bins @@ -526,9 +561,11 @@ def setup(self): per_bin_stddev = self.per_bin_stddev self._normalizer_mean = jnp.array(per_bin_mean)[ - jnp.newaxis, jnp.newaxis, :, jnp.newaxis] + jnp.newaxis, jnp.newaxis, :, jnp.newaxis + ] self._normalizer_stddev = jnp.array(per_bin_stddev)[ - jnp.newaxis, jnp.newaxis, :, jnp.newaxis] + jnp.newaxis, jnp.newaxis, :, jnp.newaxis + ] @nn.compact def __call__(self, inputs, input_paddings): @@ -537,18 +574,21 @@ def __call__(self, inputs, input_paddings): spect, spect_paddings = self.stft(inputs, input_paddings) mel_weights = linear_to_mel_weight_matrix( - num_mel_bins=p.num_bins, - num_spectrogram_bins=spect.shape[2], - sample_rate=p.sample_rate, - lower_edge_hertz=p.lower_edge_hertz, - upper_edge_hertz=p.upper_edge_hertz) + num_mel_bins=p.num_bins, + num_spectrogram_bins=spect.shape[2], + sample_rate=p.sample_rate, + lower_edge_hertz=p.lower_edge_hertz, + upper_edge_hertz=p.upper_edge_hertz, + ) mel_spectrogram = jnp.einsum('fn,btfc->btnc', mel_weights, spect) logmel_spectrogram = jnp.log(jnp.maximum(mel_spectrogram, p.output_floor)) normalized_logmel_spectrogram = ( - (logmel_spectrogram - self._normalizer_mean) / self._normalizer_stddev) + logmel_spectrogram - self._normalizer_mean + ) / self._normalizer_stddev - normalized_logmel_spectrogram = jnp.squeeze(normalized_logmel_spectrogram, - -1) + normalized_logmel_spectrogram = jnp.squeeze( + normalized_logmel_spectrogram, -1 + ) return normalized_logmel_spectrogram, spect_paddings diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index 593d463c3..9fc2e39ef 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -16,34 +16,32 @@ import math from typing import Any, List, Optional -from flax import linen as nn -from flax import struct import jax import jax.numpy as jnp import numpy as np +from flax import linen as nn +from flax import struct -from algoperf.workloads.librispeech_conformer.librispeech_jax import \ - librispeech_preprocessor as preprocessor -from algoperf.workloads.librispeech_conformer.librispeech_jax import \ - spectrum_augmenter +from algoperf.jax_utils import Dropout +from algoperf.workloads.librispeech_conformer.librispeech_jax import ( + librispeech_preprocessor as preprocessor, +) +from algoperf.workloads.librispeech_conformer.librispeech_jax import ( + spectrum_augmenter, +) + +DROPOUT_RATE = 0.1 @struct.dataclass class ConformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int = 1024 dtype: Any = jnp.float32 encoder_dim: int = 512 num_attention_heads: int = 8 num_encoder_layers: int = 4 - attention_dropout_rate: float = 0.0 - # If None, defaults to 0.1. - attention_residual_dropout_rate: Optional[float] = 0.1 - # If None, defaults to 0.0. 
- conv_residual_dropout_rate: Optional[float] = 0.0 - feed_forward_dropout_rate: float = 0.0 - # If None, defaults to 0.1. - feed_forward_residual_dropout_rate: Optional[float] = 0.1 convolution_kernel_size: int = 5 feed_forward_expansion_factor: int = 4 freq_mask_count: int = 2 @@ -53,8 +51,6 @@ class ConformerConfig: time_mask_max_ratio: float = 0.05 time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True - # If None, defaults to 0.1. - input_dropout_rate: Optional[float] = 0.1 batch_norm_momentum: float = 0.999 batch_norm_epsilon: float = 0.001 use_specaug: bool = True @@ -73,6 +69,7 @@ class LayerNorm(nn.Module): zeros, this differs from default flax implementation of multiplying by scale and initializing to ones. """ + dim: int = 0 epsilon: float = 1e-6 @@ -86,7 +83,7 @@ def __call__(self, inputs): var = jnp.mean(jnp.square(inputs - mean), axis=[-1], keepdims=True) normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon) - normed_inputs *= (1 + self.scale) + normed_inputs *= 1 + self.scale normed_inputs += self.bias return normed_inputs @@ -99,39 +96,41 @@ class Subsample(nn.Module): encoder_dim: model dimension of conformer. input_dropout_rate: dropout rate for inputs. """ + encoder_dim: int = 0 - input_dropout_rate: float = 0.0 @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE): output_paddings = input_paddings outputs = jnp.expand_dims(inputs, axis=-1) outputs, output_paddings = Conv2dSubsampling( - input_channels=1, output_channels=self.encoder_dim)( - outputs, output_paddings) + input_channels=1, output_channels=self.encoder_dim + )(outputs, output_paddings) outputs, output_paddings = Conv2dSubsampling( - input_channels=self.encoder_dim, - output_channels=self.encoder_dim)(outputs, output_paddings) + input_channels=self.encoder_dim, output_channels=self.encoder_dim + )(outputs, output_paddings) batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape outputs = jnp.reshape( - outputs, (batch_size, subsampled_lengths, subsampled_dims * channels)) + outputs, (batch_size, subsampled_lengths, subsampled_dims * channels) + ) outputs = nn.Dense( - self.encoder_dim, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - outputs) + self.encoder_dim, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform(), + )(outputs) outputs = outputs + AddPositionalEmbedding(embedding_dim=self.encoder_dim)( - seq_length=outputs.shape[1]) + seq_length=outputs.shape[1] + ) - outputs = nn.Dropout( - rate=self.input_dropout_rate, deterministic=not train)( - outputs) + outputs = Dropout(rate=dropout_rate, deterministic=not train)( + outputs, rate=dropout_rate + ) return outputs, output_paddings @@ -143,6 +142,7 @@ class Conv2dSubsampling(nn.Module): 2) Also performs strided convolution over input_paddings to return the correct paddings for downstream layers. 
""" + input_channels: int = 0 output_channels: int = 0 filter_stride: List[int] = (2, 2) @@ -150,24 +150,26 @@ class Conv2dSubsampling(nn.Module): def setup(self): self.filter_shape = (3, 3, self.input_channels, self.output_channels) - self.kernel = self.param('kernel', - nn.initializers.xavier_uniform(), - self.filter_shape) + self.kernel = self.param( + 'kernel', nn.initializers.xavier_uniform(), self.filter_shape + ) self.bias = self.param( - 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) + 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels + ) @nn.compact def __call__(self, inputs, paddings): # Computing strided convolution to subsample inputs. feature_group_count = inputs.shape[3] // self.filter_shape[2] outputs = jax.lax.conv_general_dilated( - lhs=inputs, - rhs=self.kernel, - window_strides=self.filter_stride, - padding=self.padding, - rhs_dilation=(1, 1), - dimension_numbers=('NHWC', 'HWIO', 'NHWC'), - feature_group_count=feature_group_count) + lhs=inputs, + rhs=self.kernel, + window_strides=self.filter_stride, + padding=self.padding, + rhs_dilation=(1, 1), + dimension_numbers=('NHWC', 'HWIO', 'NHWC'), + feature_group_count=feature_group_count, + ) outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,)) outputs = nn.relu(outputs) @@ -178,64 +180,64 @@ def __call__(self, inputs, paddings): pad_len = (input_length + stride - 1) // stride * stride - input_length out_padding = jax.lax.conv_general_dilated( - lhs=paddings[:, :, None], - rhs=jnp.ones([1, 1, 1]), - window_strides=self.filter_stride[:1], - padding=[(0, pad_len)], - dimension_numbers=('NHC', 'HIO', 'NHC')) + lhs=paddings[:, :, None], + rhs=jnp.ones([1, 1, 1]), + window_strides=self.filter_stride[:1], + padding=[(0, pad_len)], + dimension_numbers=('NHC', 'HIO', 'NHC'), + ) out_padding = jnp.squeeze(out_padding, axis=-1) # Mask outputs by correct paddings to ensure padded elements in inputs map # to padded value in outputs. - outputs = outputs * \ - (1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1)) + outputs = outputs * ( + 1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1) + ) return outputs, out_padding class FeedForwardModule(nn.Module): - """Feedforward block of conformer layer. 
- """ + """Feedforward block of conformer layer.""" + config: ConformerConfig @nn.compact - def __call__(self, inputs, padding_mask=None, train=False): + def __call__( + self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_RATE + ): config = self.config - inputs = LayerNorm(dim=config.encoder_dim)(inputs) inputs = nn.Dense( - config.encoder_dim * config.feed_forward_expansion_factor, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - inputs) + config.encoder_dim * config.feed_forward_expansion_factor, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform(), + )(inputs) if config.activation_function_name == 'swish': activation_fn = nn.swish elif config.activation_function_name == 'gelu': activation_fn = nn.gelu else: - raise ValueError('Only "swish" and "gelu" are supported ' - 'config.activation_function_name values, recieved ' - f'{config.activation_function_name}') + raise ValueError( + 'Only "swish" and "gelu" are supported ' + 'config.activation_function_name values, recieved ' + f'{config.activation_function_name}' + ) inputs = activation_fn(inputs) - inputs = nn.Dropout(rate=config.feed_forward_dropout_rate)( - inputs, deterministic=not train) + inputs = Dropout(rate=dropout_rate)( + inputs, deterministic=not train, rate=dropout_rate + ) inputs = inputs * padding_mask inputs = nn.Dense( - config.encoder_dim, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - inputs) + config.encoder_dim, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform(), + )(inputs) inputs = inputs * padding_mask - if config.feed_forward_residual_dropout_rate is None: - feed_forward_residual_dropout_rate = 0.1 - else: - feed_forward_residual_dropout_rate = ( - config.feed_forward_residual_dropout_rate) - inputs = nn.Dropout(rate=feed_forward_residual_dropout_rate)( - inputs, deterministic=not train) + inputs = Dropout(rate=dropout_rate)(inputs, deterministic=not train) return inputs @@ -247,6 +249,7 @@ class AddPositionalEmbedding(nn.Module): max_len: maximum possible length for the input posemb_init: positional embedding initializer """ + min_timescale: int = 1 max_timescale: int = 10_000 embedding_dim: int = 512 @@ -255,21 +258,23 @@ class AddPositionalEmbedding(nn.Module): def __call__(self, seq_length): position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :] num_timescales = self.embedding_dim // 2 - log_timescale_increment = ( - math.log(float(self.max_timescale) / float(self.min_timescale)) / - jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1)) + log_timescale_increment = math.log( + float(self.max_timescale) / float(self.min_timescale) + ) / jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1) inv_timescales = self.min_timescale * jnp.exp( - jnp.arange(num_timescales, dtype=jnp.float32) * - -log_timescale_increment) + jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment + ) scaled_time = ( - position[:, :, jnp.newaxis] * - inv_timescales[jnp.newaxis, jnp.newaxis, :]) - signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], - axis=2).astype(jnp.float32) + position[:, :, jnp.newaxis] * inv_timescales[jnp.newaxis, jnp.newaxis, :] + ) + signal = jnp.concatenate( + [jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=2 + ).astype(jnp.float32) # Force usage of `np` rather than `jnp` to compute static values at trace # time. 
- signal = jnp.pad(signal, - [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]]) + signal = jnp.pad( + signal, [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]] + ) return signal @@ -277,6 +282,7 @@ def __call__(self, seq_length): # https://github.com/tensorflow/lingvo/blob/7de4ca8fff3cb28c2ecb21bbd7b02a964ce727f7/lingvo/jax/layers/attentions.py#L201 class QueryScaler(nn.Module): """A layer to scale individual dims of the query attention matrix.""" + dim: int = 0 def setup(self): @@ -286,8 +292,10 @@ def setup(self): def __call__(self, inputs): inputs_shape = inputs.shape if inputs_shape[-1] != self.dim: - raise ValueError('QueryScaler expects inputs to have' - ' same last dimension as scaling param.') + raise ValueError( + 'QueryScaler expects inputs to have' + ' same last dimension as scaling param.' + ) # 1.0/jax.nn.softplus(0.0) = 1.442695041. Hard code this number so that we # can avoid unnecessary XLA op fusion mess on TPU. @@ -302,18 +310,20 @@ def __call__(self, inputs): # Modifying flax linen default dot product attention function to add # query scaling, reference to original function here : # https://github.com/google/flax/blob/a9af38085a7a49b571cf37d375060fd683e74972/flax/linen/attention.py#L121 -def dot_product_attention(query, - key, - value, - bias=None, - mask=None, - broadcast_dropout=True, - dropout_rng=None, - dropout_rate=0., - deterministic=False, - dtype=jnp.float32, - precision=None, - temperature=1.0): +def dot_product_attention( + query, + key, + value, + bias=None, + mask=None, + broadcast_dropout=True, + dropout_rng=None, + dropout_rate=0.0, + deterministic=False, + dtype=jnp.float32, + precision=None, + temperature=1.0, +): """Computes dot-product attention given query, key, and value. This is the core function for applying attention based on @@ -352,29 +362,35 @@ def dot_product_attention(query, """ assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( - 'q, k, v batch dims must match.') + 'q, k, v batch dims must match.' + ) assert query.shape[-2] == key.shape[-2] == value.shape[-2], ( - 'q, k, v num_heads must match.') + 'q, k, v num_heads must match.' + ) assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' # compute attention weights query = QueryScaler(dim=query.shape[-1])(query) attn_weights = nn.attention.dot_product_attention_weights( - query, - key, - bias, - mask, - broadcast_dropout, - dropout_rng, - dropout_rate, - deterministic, - dtype, - precision) + query, + key, + bias, + mask, + broadcast_dropout, + dropout_rng, + dropout_rate, + deterministic, + dtype, + precision, + ) # return weighted sum over values for each query position - return jnp.einsum( - '...hqk,...khd->...qhd', attn_weights, value, - precision=precision) * temperature + return ( + jnp.einsum( + '...hqk,...khd->...qhd', attn_weights, value, precision=precision + ) + * temperature + ) class MultiHeadedSelfAttention(nn.Module): @@ -386,39 +402,39 @@ class MultiHeadedSelfAttention(nn.Module): Note: this attention implementation uses a learned scale parameter to scale query matrix before passing it to flax attention module. 
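The hard-coded 1.442695041 in QueryScaler is 1 / softplus(0), so a zero-initialized scale parameter multiplies the query by exactly 1 at init; dot_product_attention then rescales the weighted values by the attention temperature. A quick stand-alone check of the softplus identity:

import jax
import jax.numpy as jnp

w = jnp.zeros(4)  # zero-initialized per-dim scale parameter
scale = 1.442695041 * jax.nn.softplus(w)
print(scale)  # ~[1. 1. 1. 1.]

query = jnp.ones((2, 4))
print(query * scale)  # unchanged at initialization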
""" + config: ConformerConfig = None @nn.compact - def __call__(self, inputs, paddings, train): + def __call__(self, inputs, paddings, train, dropout_rate=DROPOUT_RATE): config = self.config + mask_paddings = 1 - paddings attention_mask = nn.make_attention_mask( - mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32) + mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32 + ) inputs = LayerNorm(dim=config.encoder_dim)(inputs) attention_fn = functools.partial( - dot_product_attention, temperature=config.attention_temperature) + dot_product_attention, temperature=config.attention_temperature + ) result = nn.MultiHeadDotProductAttention( - num_heads=config.num_attention_heads, - qkv_features=config.encoder_dim, - decode=False, - dtype=config.dtype, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros, - use_bias=True, - broadcast_dropout=False, - attention_fn=attention_fn, - dropout_rate=config.attention_dropout_rate, - deterministic=not train)( - inputs_q=inputs, mask=attention_mask) - - if config.attention_residual_dropout_rate is None: - attention_residual_dropout_rate = 0.1 - else: - attention_residual_dropout_rate = config.attention_residual_dropout_rate - result = nn.Dropout( - rate=attention_residual_dropout_rate, deterministic=not train)( - result) + num_heads=config.num_attention_heads, + qkv_features=config.encoder_dim, + decode=False, + dtype=config.dtype, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + use_bias=True, + broadcast_dropout=False, + attention_fn=attention_fn, + dropout_rate=dropout_rate, + deterministic=not train, + )(inputs_q=inputs, mask=attention_mask) + + result = Dropout(rate=dropout_rate, deterministic=not train)( + result, rate=dropout_rate + ) return result @@ -435,30 +451,27 @@ class BatchNorm(nn.Module): and the corresponding defaults for momentum and epsilon have been copied over from lingvo. 
""" + config: ConformerConfig def setup(self): dim = self.config.encoder_dim dtype = self.config.dtype - self.ra_mean = self.variable('batch_stats', - 'mean', - lambda s: jnp.zeros(s, dtype), - dim) - self.ra_var = self.variable('batch_stats', - 'var', - lambda s: jnp.ones(s, dtype), - dim) + self.ra_mean = self.variable( + 'batch_stats', 'mean', lambda s: jnp.zeros(s, dtype), dim + ) + self.ra_var = self.variable( + 'batch_stats', 'var', lambda s: jnp.ones(s, dtype), dim + ) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) @nn.compact - def __call__(self, - inputs, - input_paddings, - update_batch_norm, - use_running_average_bn): + def __call__( + self, inputs, input_paddings, update_batch_norm, use_running_average_bn + ): rank = inputs.ndim reduce_over_dims = list(range(0, rank - 1)) @@ -475,23 +488,25 @@ def __call__(self, mask = 1.0 - padding sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True) count_v = jnp.sum( - jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True) + jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True + ) count_v = jnp.maximum(count_v, 1.0) mean = sum_v / count_v sum_vv = jnp.sum( - (inputs - mean) * (inputs - mean) * mask, - axis=reduce_over_dims, - keepdims=True) + (inputs - mean) * (inputs - mean) * mask, + axis=reduce_over_dims, + keepdims=True, + ) var = sum_vv / count_v if update_batch_norm: - self.ra_mean.value = momentum * \ - self.ra_mean.value + (1 - momentum) * mean - self.ra_var.value = momentum * \ - self.ra_var.value + (1 - momentum) * var + self.ra_mean.value = ( + momentum * self.ra_mean.value + (1 - momentum) * mean + ) + self.ra_var.value = momentum * self.ra_var.value + (1 - momentum) * var inv = (1 + self.gamma) / jnp.sqrt(var + epsilon) bn_output = (inputs - mean) * inv + self.beta @@ -520,67 +535,68 @@ class ConvolutionBlock(nn.Module): | output """ + config: ConformerConfig @nn.compact - def __call__(self, - inputs, - input_paddings, - train, - update_batch_norm, - use_running_average_bn): + def __call__( + self, + inputs, + input_paddings, + train, + update_batch_norm, + use_running_average_bn, + dropout_rate=DROPOUT_RATE, + ): config = self.config inputs = LayerNorm(dim=config.encoder_dim)(inputs) input_gated1 = nn.Dense( - config.encoder_dim, - kernel_init=nn.initializers.xavier_uniform(), - use_bias=True)( - inputs) + config.encoder_dim, + kernel_init=nn.initializers.xavier_uniform(), + use_bias=True, + )(inputs) input_gated2 = nn.Dense( - config.encoder_dim, - kernel_init=nn.initializers.xavier_uniform(), - use_bias=True)( - inputs) + config.encoder_dim, + kernel_init=nn.initializers.xavier_uniform(), + use_bias=True, + )(inputs) inputs = input_gated1 * jax.nn.sigmoid(input_gated2) inputs = inputs * (1 - jnp.expand_dims(input_paddings, -1)) inputs = nn.Conv( - features=config.encoder_dim, - kernel_size=(config.convolution_kernel_size,), - strides=(1,), - padding='SAME', - feature_group_count=config.encoder_dim, - use_bias=False, - kernel_init=nn.initializers.xavier_uniform())( - inputs) - - inputs = BatchNorm(config)(inputs, - input_paddings, - update_batch_norm, - use_running_average_bn) + features=config.encoder_dim, + kernel_size=(config.convolution_kernel_size,), + strides=(1,), + padding='SAME', + feature_group_count=config.encoder_dim, + use_bias=False, + kernel_init=nn.initializers.xavier_uniform(), + )(inputs) + + inputs = BatchNorm(config)( + inputs, input_paddings, update_batch_norm, 
use_running_average_bn + ) if config.activation_function_name == 'swish': activation_fn = nn.swish elif config.activation_function_name == 'gelu': activation_fn = nn.gelu else: - raise ValueError('Only "swish" and "gelu" are supported ' - 'config.activation_function_name values, recieved ' - f'{config.activation_function_name}') + raise ValueError( + 'Only "swish" and "gelu" are supported ' + 'config.activation_function_name values, recieved ' + f'{config.activation_function_name}' + ) inputs = activation_fn(inputs) inputs = nn.Dense( - config.encoder_dim, kernel_init=nn.initializers.xavier_uniform())( - inputs) + config.encoder_dim, kernel_init=nn.initializers.xavier_uniform() + )(inputs) - if config.conv_residual_dropout_rate is None: - conv_residual_dropout_rate = 0.0 - else: - conv_residual_dropout_rate = config.conv_residual_dropout_rate - inputs = nn.Dropout( - rate=conv_residual_dropout_rate, deterministic=not train)( - inputs) + inputs = Dropout(rate=dropout_rate, deterministic=not train)( + inputs, rate=dropout_rate + ) return inputs @@ -597,34 +613,42 @@ class ConformerBlock(nn.Module): y = layer_norm(x) """ + config: ConformerConfig @nn.compact - def __call__(self, - inputs, - input_paddings, - train, - update_batch_norm, - use_running_average): + def __call__( + self, + inputs, + input_paddings, + train, + update_batch_norm, + use_running_average, + dropout_rate=DROPOUT_RATE, + ): config = self.config padding_mask = jnp.expand_dims(1 - input_paddings, -1) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( - inputs, padding_mask, train) + inputs, padding_mask, train, dropout_rate + ) inputs = inputs + MultiHeadedSelfAttention(config=self.config)( - inputs, input_paddings, train) - - inputs = inputs + \ - ConvolutionBlock(config)(inputs, - input_paddings, - train, - update_batch_norm, - use_running_average - ) + inputs, input_paddings, train, dropout_rate=dropout_rate + ) + + inputs = inputs + ConvolutionBlock(config)( + inputs, + input_paddings, + train, + update_batch_norm, + use_running_average, + dropout_rate, + ) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( - inputs, padding_mask, train) + inputs, padding_mask, train, dropout_rate + ) if config.use_post_layer_norm: inputs = LayerNorm(dim=config.encoder_dim)(inputs) @@ -639,26 +663,30 @@ class Conformer(nn.Module): for each time step. The output is then fed into a CTC loss which eliminates the need for alignment with targets. 
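The residual layout of ConformerBlock is the usual "macaron" pattern: half-weighted feed-forward, self-attention, convolution, another half-weighted feed-forward, optional final layer norm. Written with plain callables just to make the arithmetic explicit (not the repo's API):

def conformer_block(x, ff1, mhsa, conv, ff2, layer_norm=None):
  x = x + 0.5 * ff1(x)
  x = x + mhsa(x)
  x = x + conv(x)
  x = x + 0.5 * ff2(x)
  return layer_norm(x) if layer_norm is not None else x


def identity(v):
  return v


print(conformer_block(1.0, identity, identity, identity, identity))  # 9.0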
""" + config: ConformerConfig def setup(self): self.specaug = spectrum_augmenter.SpecAug( - freq_mask_count=self.config.freq_mask_count, - freq_mask_max_bins=self.config.freq_mask_max_bins, - time_mask_count=self.config.time_mask_count, - time_mask_max_frames=self.config.time_mask_max_frames, - time_mask_max_ratio=self.config.time_mask_max_ratio, - time_masks_per_frame=self.config.time_masks_per_frame, - use_dynamic_time_mask_max_frames=self.config - .use_dynamic_time_mask_max_frames) + freq_mask_count=self.config.freq_mask_count, + freq_mask_max_bins=self.config.freq_mask_max_bins, + time_mask_count=self.config.time_mask_count, + time_mask_max_frames=self.config.time_mask_max_frames, + time_mask_max_ratio=self.config.time_mask_max_ratio, + time_masks_per_frame=self.config.time_masks_per_frame, + use_dynamic_time_mask_max_frames=self.config.use_dynamic_time_mask_max_frames, + ) @nn.compact - def __call__(self, - inputs, - input_paddings, - train, - update_batch_norm: Optional[bool] = None, - use_running_average_bn: Optional[bool] = None): + def __call__( + self, + inputs, + input_paddings, + train, + update_batch_norm: Optional[bool] = None, + use_running_average_bn: Optional[bool] = None, + dropout_rate: float = DROPOUT_RATE, + ): config = self.config outputs = inputs @@ -675,38 +703,35 @@ def __call__(self, outputs, output_paddings = preprocessor.MelFilterbankFrontend( preprocessing_config, per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, - per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)( - outputs, output_paddings) + per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR, + )(outputs, output_paddings) # Ablate random parts of input along temporal and frequency dimension # following the specaug procedure in https://arxiv.org/abs/1904.08779. if train and config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) - # Subsample input by a factor of 4 by performing strided convolutions. - if config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = config.input_dropout_rate outputs, output_paddings = Subsample( - encoder_dim=config.encoder_dim, - input_dropout_rate=input_dropout_rate)( - outputs, output_paddings, train) + encoder_dim=config.encoder_dim, + )(outputs, output_paddings, train, dropout_rate=dropout_rate) # Run the conformer encoder layers. for _ in range(config.num_encoder_layers): - outputs = ConformerBlock(config)(outputs, - output_paddings, - train, - update_batch_norm, - use_running_average_bn) + outputs = ConformerBlock(config)( + outputs, + output_paddings, + train, + update_batch_norm, + use_running_average_bn, + dropout_rate, + ) outputs = LayerNorm(config.encoder_dim)(outputs) # Run the decoder which in this case is a trivial projection layer. outputs = nn.Dense( - config.vocab_size, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - outputs) + config.vocab_size, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform(), + )(outputs) return outputs, output_paddings diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py index 2a6f73d4d..d9c1e301b 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py @@ -17,6 +17,7 @@ class SpecAug(nn.Module): This is an essential component in speech recognition models that helps achieve better word error rates. 
""" + freq_mask_count: int = 2 freq_mask_max_bins: int = 27 time_mask_count: int = 10 @@ -28,26 +29,30 @@ class SpecAug(nn.Module): def next_prng_key(self, name='dropout'): return self.make_rng(name) - def _get_mask(self, - batch_size, - choose_range, - mask_size, - max_length=None, - masks_per_frame=0.0, - multiplicity=1, - max_ratio=1.0): + def _get_mask( + self, + batch_size, + choose_range, + mask_size, + max_length=None, + masks_per_frame=0.0, + multiplicity=1, + max_ratio=1.0, + ): # Sample lengths for multiple masks. if max_length and max_length > 0: max_length = jnp.tile(max_length, (batch_size,)) else: max_length = choose_range * max_ratio masked_portion = jax.random.uniform( - key=self.next_prng_key(), - shape=(batch_size, multiplicity), - minval=0.0, - maxval=1.0) - masked_frame_size = jnp.einsum('b,bm->bm', max_length, - masked_portion).astype(jnp.int32) + key=self.next_prng_key(), + shape=(batch_size, multiplicity), + minval=0.0, + maxval=1.0, + ) + masked_frame_size = jnp.einsum( + 'b,bm->bm', max_length, masked_portion + ).astype(jnp.int32) # Make sure the sampled length was smaller than max_ratio * length_bound. # Note that sampling in this way was biased # (shorter sequence may over-masked.) @@ -57,7 +62,8 @@ def _get_mask(self, # Choose starting point. random_start = jax.random.uniform( - key=self.next_prng_key(), shape=(batch_size, multiplicity), maxval=1.0) + key=self.next_prng_key(), shape=(batch_size, multiplicity), maxval=1.0 + ) start_with_in_valid_range = random_start * (choose_range - length + 1) start = start_with_in_valid_range.astype(jnp.int32) @@ -78,11 +84,13 @@ def _get_mask(self, # Sum masks with appropriate multiplicity. if masks_per_frame > 0: multiplicity_weights = jnp.tile( - jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0), - [batch_size, 1]) + jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0), + [batch_size, 1], + ) multiplicity_tensor = masks_per_frame * choose_range - multiplicity_weights = (multiplicity_weights < - multiplicity_tensor).astype(jnp.int32) + multiplicity_weights = ( + multiplicity_weights < multiplicity_tensor + ).astype(jnp.int32) pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights) else: pre_mask = jnp.einsum('bmt->bt', pre_mask) @@ -98,8 +106,9 @@ def _time_mask(self, inputs, length): max_ratio = self.time_mask_max_ratio # If maximum mask length is zero, do nothing. - if ((time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames) or - max_ratio <= 0.0): + if ( + time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames + ) or max_ratio <= 0.0: return inputs if multiplicity == 0: return inputs @@ -111,13 +120,14 @@ def _time_mask(self, inputs, length): time_mask_max_frames = None # Create masks in time direction and apply. block_arrays = self._get_mask( - batch_size, - choose_range=length, - mask_size=time_length, - max_length=time_mask_max_frames, - masks_per_frame=self.time_masks_per_frame, - multiplicity=multiplicity, - max_ratio=max_ratio) + batch_size, + choose_range=length, + mask_size=time_length, + max_length=time_mask_max_frames, + masks_per_frame=self.time_masks_per_frame, + multiplicity=multiplicity, + max_ratio=max_ratio, + ) outputs = jnp.einsum('bxy,bx->bxy', inputs, block_arrays) return outputs @@ -136,13 +146,14 @@ def _frequency_mask(self, inputs): choose_range = jnp.tile(num_freq, (batch_size,)) # Create masks in frequency direction and apply. 
block_arrays = self._get_mask( - batch_size, - choose_range=choose_range, - mask_size=num_freq, - max_length=freq_mask_max_bins, - masks_per_frame=0.0, - multiplicity=multiplicity, - max_ratio=1.0) + batch_size, + choose_range=choose_range, + mask_size=num_freq, + max_length=freq_mask_max_bins, + masks_per_frame=0.0, + multiplicity=multiplicity, + max_ratio=1.0, + ) outputs = jnp.einsum('bxy,by->bxy', inputs, block_arrays) return outputs diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 39012a20d..819e57a69 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -2,31 +2,28 @@ import math from typing import Dict, Iterator, Optional, Tuple -from flax import jax_utils -from flax.core import pop import flax.linen as nn import jax -from jax import lax import jax.numpy as jnp import numpy as np import optax import torch +from flax import jax_utils +from flax.core import pop +from jax import lax -from algoperf import data_utils -from algoperf import param_utils -from algoperf import spec -from algoperf.workloads.librispeech_conformer import metrics -from algoperf.workloads.librispeech_conformer import workload -from algoperf.workloads.librispeech_conformer.input_pipeline import \ - LibriSpeechDataset +from algoperf import data_utils, param_utils, spec +from algoperf.workloads.librispeech_conformer import metrics, workload +from algoperf.workloads.librispeech_conformer.input_pipeline import ( + LibriSpeechDataset, +) from algoperf.workloads.librispeech_conformer.librispeech_jax import models class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload): - - def __init__(self, - tokenizer_vocab_path: Optional[str] = None, - use_specaug: bool = True) -> None: + def __init__( + self, tokenizer_vocab_path: Optional[str] = None, use_specaug: bool = True + ) -> None: super().__init__() self.metrics_bundle = metrics.get_metrics_bundle(tokenizer_vocab_path) self.use_specaug = use_specaug @@ -38,7 +35,8 @@ def __init__(self, def eval_num_workers(self) -> int: if self._eval_num_workers is None: raise ValueError( - 'eval_num_workers property must be set before workload is used.') + 'eval_num_workers property must be set before workload is used.' + ) return self._eval_num_workers @eval_num_workers.setter @@ -58,13 +56,12 @@ def attention_temperature(self) -> float: return 1.0 def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + self, + rng: spec.RandomState, + ) -> spec.ModelInitState: """Conformer model init function. - Here we use dropout_rate as *_residual_dropout_rate, and aux_dropout_rate as + Here we use dropout_rate as *_residual_dropout_rate, and for input_dropout_rate. 
""" if self.use_gelu: @@ -72,24 +69,22 @@ def init_model_fn( else: activation_function_name = 'swish' model_config = models.ConformerConfig( - attention_residual_dropout_rate=dropout_rate, - feed_forward_residual_dropout_rate=dropout_rate, - input_dropout_rate=aux_dropout_rate, - use_specaug=self.use_specaug, - attention_temperature=self.attention_temperature, - use_post_layer_norm=self.use_post_layer_norm, - activation_function_name=activation_function_name) + use_specaug=self.use_specaug, + attention_temperature=self.attention_temperature, + use_post_layer_norm=self.use_post_layer_norm, + activation_function_name=activation_function_name, + ) + self._model = models.Conformer(model_config) input_shape = [(320000,), (320000,)] fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape] model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) - params_rng, dropout_rng = jax.random.split(rng, 2) - variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, - *fake_input_batch) + params_rng, _ = jax.random.split(rng, 2) + variables = model_init_fn({'params': params_rng}, *fake_input_batch) - model_state, params = pop(variables, "params") + model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -101,47 +96,52 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None, + dropout_rate: Optional[float] = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN if update_batch_norm or is_train_mode: (logits, logit_paddings), new_model_state = self._model.apply( - variables, - inputs, - input_paddings, - train=True, - rngs={'dropout' : rng}, - mutable=['batch_stats'], - use_running_average_bn=use_running_average_bn) + variables, + inputs, + input_paddings, + train=True, + rngs={'dropout': rng}, + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn, + dropout_rate=dropout_rate, + ) return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( - variables, - inputs, - input_paddings, - train=False, - mutable=False, - use_running_average_bn=use_running_average_bn) + variables, + inputs, + input_paddings, + train=False, + mutable=False, + use_running_average_bn=use_running_average_bn, + ) return (logits, logit_paddings), model_state def _build_input_queue( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + 
cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del data_rng del cache del repeat_final_dataset @@ -160,38 +160,41 @@ def _build_input_queue( ds = LibriSpeechDataset(split=split, data_dir=data_dir) dataloader = data_utils.cycle( - torch.utils.data.DataLoader( - ds, - batch_size=global_batch_size, - shuffle=train, - sampler=None, - num_workers=4, - prefetch_factor=10, - pin_memory=False, - drop_last=train, - )) + torch.utils.data.DataLoader( + ds, + batch_size=global_batch_size, + shuffle=train, + sampler=None, + num_workers=4, + prefetch_factor=10, + pin_memory=False, + drop_last=train, + ) + ) for batch in iter(dataloader): inputs, input_paddings = batch['inputs'] targets, target_paddings = batch['targets'] numpy_batch = { - 'inputs': (inputs.numpy(), input_paddings.numpy()), - 'targets': (targets.numpy(), target_paddings.numpy()), + 'inputs': (inputs.numpy(), input_paddings.numpy()), + 'targets': (targets.numpy(), target_paddings.numpy()), } padded_batch = data_utils.shard_and_maybe_pad_np( - numpy_batch, padding_value=1.0) + numpy_batch, padding_value=1.0 + ) yield padded_batch # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding) - logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding) - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding) + logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding) + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -202,10 +205,9 @@ def loss_fn( logits, logit_paddings = logits_batch targets, target_paddings = label_batch logprobs = nn.log_softmax(logits) - per_example_losses = self.ctc_loss(logprobs, - logit_paddings, - targets, - target_paddings) + per_example_losses = self.ctc_loss( + logprobs, logit_paddings, targets, target_paddings + ) # mask_batch is assumed to be shape [batch]. 
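self.ctc_loss here is a thin wrapper around optax.ctc_loss (defined a bit further down); minimal usage with the same argument conventions, where id 0 is the CTC blank and both paddings are 0/1 vectors:

import jax
import jax.numpy as jnp
import optax

batch, frames, vocab = 2, 10, 16
logits = jax.random.normal(jax.random.PRNGKey(0), (batch, frames, vocab))
logit_paddings = jnp.zeros((batch, frames))
labels = jnp.array([[3, 5, 7, 9], [2, 4, 0, 0]])  # zeros mark padded slots
label_paddings = jnp.array([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]])

per_example = optax.ctc_loss(
  logits=logits,
  logit_paddings=logit_paddings,
  labels=labels,
  label_paddings=label_paddings,
  blank_id=0,
)
print(per_example.shape)  # (2,): one loss per utterance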
if mask_batch is not None: per_example_losses *= mask_batch @@ -215,23 +217,26 @@ def loss_fn( n_valid_examples = jnp.maximum(mask_batch.sum(), 1) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } - def ctc_loss(self, - logits: spec.Tensor, - logit_paddings: spec.Tensor, - labels: spec.Tensor, - label_paddings: spec.Tensor, - blank_id: int = 0) -> spec.Tensor: + def ctc_loss( + self, + logits: spec.Tensor, + logit_paddings: spec.Tensor, + labels: spec.Tensor, + label_paddings: spec.Tensor, + blank_id: int = 0, + ) -> spec.Tensor: return optax.ctc_loss( - logits=logits, - logit_paddings=logit_paddings, - labels=labels, - label_paddings=label_paddings, - blank_id=blank_id) + logits=logits, + logit_paddings=logit_paddings, + labels=labels, + label_paddings=label_paddings, + blank_id=blank_id, + ) # Adapted from lingvo's greedy decoding logic here: # https://github.com/tensorflow/lingvo/blob/2ee26814c57b7dcead3f0382170f2f3da006f810/lingvo/jax/layers/ctc_objectives.py#L138. @@ -242,21 +247,22 @@ def sequence_mask(self, lengths: spec.Tensor, maxlen: int) -> spec.Tensor: c = jnp.less_equal(b, lengths[:, jnp.newaxis]).astype(lengths.dtype) return c - def collapse_and_remove_blanks(self, - labels: spec.Tensor, - seq_length: spec.Tensor, - blank_id: int = 0) -> spec.Tensor: + def collapse_and_remove_blanks( + self, labels: spec.Tensor, seq_length: spec.Tensor, blank_id: int = 0 + ) -> spec.Tensor: b, t = labels.shape # Zap out blank. blank_mask = 1 - jnp.equal(labels, blank_id) labels = (labels * blank_mask).astype(labels.dtype) # Mask labels that don't equal previous label. - label_mask = jnp.concatenate([ + label_mask = jnp.concatenate( + [ jnp.ones_like(labels[:, :1], dtype=jnp.int32), jnp.not_equal(labels[:, 1:], labels[:, :-1]), - ], - axis=1) + ], + axis=1, + ) # Filter labels that aren't in the original sequence. maxlen = labels.shape[1] @@ -292,12 +298,14 @@ def collapse_and_remove_blanks(self, # Reshape back to square batch. 
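collapse_and_remove_blanks implements, with masks and gathers so it stays jittable, the standard CTC post-processing: merge consecutive repeats, then drop blanks. A tiny eager reference version for a single sequence:

def collapse_and_remove_blanks_reference(frame_ids, blank_id=0):
  out, prev = [], None
  for t in frame_ids:
    if t != prev and t != blank_id:
      out.append(t)
    prev = t
  return out


print(collapse_and_remove_blanks_reference([0, 3, 3, 0, 0, 5, 5, 5, 0, 3]))
# [3, 5, 3]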
batch_size = labels.shape[0] new_shape = [batch_size, new_maxlen] - return (jnp.reshape(flat, new_shape).astype(labels.dtype), - new_seq_len.astype(seq_length.dtype)) + return ( + jnp.reshape(flat, new_shape).astype(labels.dtype), + new_seq_len.astype(seq_length.dtype), + ) def greedy_decode( - self, logits: spec.Tensor, - logit_paddings: spec.Tensor) -> Tuple[spec.Tensor, spec.Tensor]: + self, logits: spec.Tensor, logit_paddings: spec.Tensor + ) -> Tuple[spec.Tensor, spec.Tensor]: per_frame_max = jnp.argmax(logits, axis=-1) seqlen = jnp.sum(1.0 - logit_paddings, axis=-1) hyp, _ = self.collapse_and_remove_blanks(per_frame_max, seqlen, blank_id=0) @@ -305,45 +313,51 @@ def greedy_decode( return hyp, hyp_paddings @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.pmap, + axis_name='batch', + in_axes=(None, 0, 0, 0, None), + static_broadcasted_argnums=(0,), + ) def eval_step_pmapped( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: (logits, logit_paddings), _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) decoded, decoded_paddings = self.greedy_decode(logits, logit_paddings) loss = self.loss_fn(batch['targets'], (logits, logit_paddings)) targets, target_paddings = batch['targets'] return self.metrics_bundle.gather_from_model_output( - loss_dict=loss, - decoded=decoded, - decoded_paddings=decoded_paddings, - targets=targets, - target_paddings=target_paddings, - axis_name='batch') - - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + loss_dict=loss, + decoded=decoded, + decoded_paddings=decoded_paddings, + targets=targets, + target_paddings=target_paddings, + axis_name='batch', + ) + + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del global_step if model_state is not None and len(model_state) > 0: @@ -353,15 +367,15 @@ def _eval_model_on_split(self, num_batches = int(math.ceil(num_examples / global_batch_size)) if split not in self._eval_iters: self._eval_iters[split] = self._build_input_queue( - rng, split, data_dir, global_batch_size, num_batches=num_batches) + rng, split, data_dir, global_batch_size, num_batches=num_batches + ) metrics_report = None for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - computed_metrics = self.eval_step_pmapped(params, - eval_batch, - model_state, - rng).unreplicate() + computed_metrics = self.eval_step_pmapped( + params, eval_batch, model_state, rng + ).unreplicate() if metrics_report is None: metrics_report = computed_metrics @@ -374,7 +388,8 @@ def _eval_model_on_split(self, return computed_metrics def 
sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: + self, model_state: spec.ModelAuxiliaryState + ) -> spec.ModelAuxiliaryState: # An axis_name is passed to pmap which can then be used by pmean. # In this case each device has its own version of the batch statistics and # we average them. @@ -385,8 +400,8 @@ def sync_batch_stats( class LibriSpeechConformerAttentionTemperatureWorkload( - LibriSpeechConformerWorkload): - + LibriSpeechConformerWorkload +): @property def attention_temperature(self) -> float: return 1.6 @@ -401,7 +416,6 @@ def test_target_value(self) -> float: class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload): - @property def use_post_layer_norm(self) -> bool: return False @@ -416,7 +430,6 @@ def test_target_value(self) -> float: class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload): - @property def use_gelu(self) -> bool: return True diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py index db1e24521..647b8ff0c 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -2,34 +2,36 @@ https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py. """ +import math from dataclasses import dataclass from functools import partial -import math from typing import Tuple import torch +import torch.nn.functional as F from torch import nn from torch.nn import init -import torch.nn.functional as F -from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ - preprocessor -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ - SpecAug +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import ( + preprocessor, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import ( + SpecAug, +) + +DROPOUT_RATE = 0.1 @dataclass class ConformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int = 1024 encoder_dim: int = 512 num_attention_heads: int = 8 num_encoder_layers: int = 4 attention_dropout_rate: float = 0.0 - attention_residual_dropout_rate: float = 0.1 - conv_residual_dropout_rate: float = 0.0 feed_forward_dropout_rate: float = 0.0 - feed_forward_residual_dropout_rate: float = 0.1 convolution_kernel_size: int = 5 feed_forward_expansion_factor: int = 4 freq_mask_count: int = 2 @@ -39,7 +41,6 @@ class ConformerConfig: time_mask_max_ratio: float = 0.05 time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True - input_dropout_rate: float = 0.1 batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 use_specaug: bool = True @@ -60,7 +61,6 @@ def initialize(m): class LayerNorm(nn.Module): - def __init__(self, dim, epsilon=1e-6): super().__init__() self.dim = dim @@ -74,28 +74,25 @@ def forward(self, x): class Subsample(nn.Module): - - def __init__(self, - encoder_dim: int = 0, - input_dropout_rate: float = 0.0, - num_bins: int = 80): + def __init__(self, encoder_dim: int = 0, num_bins: int = 80): super().__init__() self.encoder_dim = encoder_dim - self.input_dropout_rate = input_dropout_rate self.conv1 = Conv2dSubsampling( - input_channels=1, output_channels=encoder_dim) + input_channels=1, output_channels=encoder_dim + ) self.conv2 = Conv2dSubsampling( - input_channels=encoder_dim, 
output_channels=encoder_dim) + input_channels=encoder_dim, output_channels=encoder_dim + ) self.linear = nn.Linear( - in_features=self.encoder_dim * num_bins // 4, - out_features=self.encoder_dim, - bias=True) + in_features=self.encoder_dim * num_bins // 4, + out_features=self.encoder_dim, + bias=True, + ) self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim) - self.dropout = nn.Dropout(p=self.input_dropout_rate, inplace=True) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate): output_paddings = input_paddings outputs = inputs[:, None, :, :] @@ -103,24 +100,27 @@ def forward(self, inputs, input_paddings): outputs, output_paddings = self.conv2(outputs, output_paddings) batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape - outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size, - subsampled_lengths, - subsampled_dims * channels) + outputs = outputs.permute(0, 2, 3, 1).reshape( + batch_size, subsampled_lengths, subsampled_dims * channels + ) outputs = self.linear(outputs) outputs = outputs + self.pos_encode(seq_length=outputs.shape[1]) - outputs = self.dropout(outputs) + outputs = F.dropout( + outputs, dropout_rate, training=self.training, inplace=True + ) return outputs, output_paddings class Conv2dSubsampling(nn.Module): - - def __init__(self, - input_channels: int, - output_channels: int, - filter_stride: Tuple[int] = (2, 2), - padding: str = 'SAME'): + def __init__( + self, + input_channels: int, + output_channels: int, + filter_stride: Tuple[int] = (2, 2), + padding: str = 'SAME', + ): super().__init__() self.input_channels = input_channels @@ -131,7 +131,8 @@ def __init__(self, self.filter_shape = (output_channels, input_channels, 3, 3) self.kernel = nn.Parameter( - torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape))) + torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape)) + ) self.bias = nn.Parameter(torch.zeros(output_channels)) self.register_buffer('paddings_kernel', torch.ones([1, 1, 1])) @@ -161,12 +162,13 @@ def forward(self, inputs, paddings): else: in_ = inputs outputs = F.conv2d( - input=in_, - weight=self.kernel, - bias=self.bias, - stride=self.filter_stride, - dilation=(1, 1), - groups=groups) + input=in_, + weight=self.kernel, + bias=self.bias, + stride=self.filter_stride, + dilation=(1, 1), + groups=groups, + ) outputs = F.relu(outputs) @@ -174,42 +176,37 @@ def forward(self, inputs, paddings): stride = self.filter_stride[0] pad_len = (input_length + stride - 1) // stride * stride - input_length padded_paddings = F.pad( - paddings[:, None, :], (0, pad_len), mode='constant', value=0) + paddings[:, None, :], (0, pad_len), mode='constant', value=0 + ) out_padding = F.conv1d( - input=padded_paddings, - weight=self.paddings_kernel, - stride=self.filter_stride[:1]) + input=padded_paddings, + weight=self.paddings_kernel, + stride=self.filter_stride[:1], + ) out_padding = out_padding.squeeze(dim=1) outputs = outputs * (1 - out_padding[:, None, :, None]) return outputs, out_padding class FeedForwardModule(nn.Module): - def __init__(self, config: ConformerConfig): super().__init__() self.config = config self.ln = LayerNorm(dim=config.encoder_dim) self.linear1 = nn.Linear( - in_features=config.encoder_dim, - out_features=config.encoder_dim * config.feed_forward_expansion_factor, - bias=True) + in_features=config.encoder_dim, + out_features=config.encoder_dim * config.feed_forward_expansion_factor, + bias=True, + ) self.dropout1 = 
nn.Dropout(p=config.feed_forward_dropout_rate, inplace=True) self.linear2 = nn.Linear( - in_features=config.encoder_dim * config.feed_forward_expansion_factor, - out_features=config.encoder_dim, - bias=True) - - if config.feed_forward_residual_dropout_rate is None: - feed_forward_residual_dropout_rate = 0.1 - else: - feed_forward_residual_dropout_rate = ( - config.feed_forward_residual_dropout_rate) - self.dropout2 = nn.Dropout( - p=feed_forward_residual_dropout_rate, inplace=True) + in_features=config.encoder_dim * config.feed_forward_expansion_factor, + out_features=config.encoder_dim, + bias=True, + ) - def forward(self, inputs, padding_mask): + def forward(self, inputs, padding_mask, dropout_rate): inputs = self.ln(inputs) inputs = self.linear1(inputs) if self.config.activation_function_name == 'swish': @@ -218,51 +215,58 @@ def forward(self, inputs, padding_mask): # Use tanh approximation of GELU which is default for jax activation_fn = partial(F.gelu, approximate='tanh') else: - raise ValueError('Only "swish" and "gelu" are supported ' - 'config.activation_function_name values, recieved ' - f'{self.config.activation_function_name}') + raise ValueError( + 'Only "swish" and "gelu" are supported ' + 'config.activation_function_name values, recieved ' + f'{self.config.activation_function_name}' + ) inputs = activation_fn(inputs) inputs = self.dropout1(inputs) inputs = inputs * padding_mask inputs = self.linear2(inputs) inputs = inputs * padding_mask - inputs = self.dropout2(inputs) + inputs = F.dropout( + inputs, dropout_rate, training=self.training, inplace=True + ) return inputs class AddPositionalEmbedding(nn.Module): - - def __init__(self, - min_timescale: int = 1, - max_timescale: int = 10_000, - embedding_dim: int = 512): + def __init__( + self, + min_timescale: int = 1, + max_timescale: int = 10_000, + embedding_dim: int = 512, + ): super().__init__() self.min_timescale = min_timescale self.max_timescale = max_timescale self.embedding_dim = embedding_dim num_timescales = self.embedding_dim // 2 log_timescale_increment = math.log( - float(self.max_timescale) / float(self.min_timescale)) / ( - num_timescales - 1) - inv_timescales = self.min_timescale * \ - torch.exp(torch.arange(num_timescales, dtype=torch.float32) - * -log_timescale_increment) + float(self.max_timescale) / float(self.min_timescale) + ) / (num_timescales - 1) + inv_timescales = self.min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float32) + * -log_timescale_increment + ) self.register_buffer('inv_timescales', inv_timescales[None, None, :]) def forward(self, seq_length): position = torch.arange( - end=seq_length, dtype=torch.float32, device=self.inv_timescales.device) + end=seq_length, dtype=torch.float32, device=self.inv_timescales.device + ) scaled_time = position[None, :, None] * self.inv_timescales signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) if self.embedding_dim % 2: signal = torch.cat( - [signal, torch.zeros(signal.shape[0], signal.shape[1], 1)], dim=2) + [signal, torch.zeros(signal.shape[0], signal.shape[1], 1)], dim=2 + ) return signal class QueryScaler(nn.Module): - def __init__(self, dim): super().__init__() self.dim = dim @@ -275,12 +279,11 @@ def forward(self, inputs): class MHSAwithQS(nn.Module): - def __init__(self, config: ConformerConfig): super().__init__() self.embed_dim = config.encoder_dim self.num_heads = config.num_attention_heads - self.dropout = config.attention_dropout_rate + self.attention_dropout_rate = config.attention_dropout_rate 
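The PyTorch model makes the same move as the JAX one: fixed nn.Dropout modules are replaced by functional F.dropout calls that take the rate at forward time and follow self.training. The pattern in isolation (toy module, not the repo's):

import torch
import torch.nn.functional as F
from torch import nn

DROPOUT_RATE = 0.1  # assumed default, mirroring the constant added in this PR


class TinyBlock(nn.Module):
  def __init__(self):
    super().__init__()
    self.linear = nn.Linear(4, 4)

  def forward(self, x, dropout_rate=DROPOUT_RATE):
    x = self.linear(x)
    return F.dropout(x, dropout_rate, training=self.training)


block = TinyBlock()
block.eval()  # dropout becomes a no-op outside training
print(block(torch.ones(2, 4), dropout_rate=0.2).shape)  # torch.Size([2, 4])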
self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim) self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim) self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads) @@ -292,20 +295,23 @@ def forward(self, inputs, key_padding_mask=None): q = self.qs(q.view(batch_size, seq_len, self.num_heads, -1)).transpose(1, 2) k = k.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) v = v.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) - out = F.scaled_dot_product_attention( + out = ( + F.scaled_dot_product_attention( query=q, key=k, value=v, attn_mask=~key_padding_mask[:, None, None], - dropout_p=self.dropout, - ).transpose(1, 2).reshape(batch_size, seq_len, embed_dim) + dropout_p=self.attention_dropout_rate, + ) + .transpose(1, 2) + .reshape(batch_size, seq_len, embed_dim) + ) out = out * self.attention_temperature out = self.out_proj(out) return out class MultiHeadedSelfAttention(nn.Module): - def __init__(self, config: ConformerConfig): super().__init__() @@ -313,24 +319,20 @@ def __init__(self, config: ConformerConfig): self.ln = LayerNorm(dim=config.encoder_dim) self.self_attention = MHSAwithQS(config) - if config.attention_residual_dropout_rate is None: - attention_residual_dropout_rate = 0.1 - else: - attention_residual_dropout_rate = config.attention_residual_dropout_rate - self.dropout = nn.Dropout(p=attention_residual_dropout_rate, inplace=True) - def forward(self, outputs, paddings): + def forward(self, outputs, paddings, dropout_rate): outputs = self.ln(outputs) outputs = self.self_attention( - outputs, - key_padding_mask=paddings == 1, + outputs, + key_padding_mask=paddings == 1, + ) + outputs = F.dropout( + outputs, dropout_rate, training=self.training, inplace=True ) - outputs = self.dropout(outputs) return outputs class BatchNorm(nn.Module): - def __init__(self, config: ConformerConfig): super().__init__() running_mean = torch.zeros(config.encoder_dim) @@ -345,8 +347,8 @@ def __init__(self, config: ConformerConfig): self.epsilon = config.batch_norm_epsilon def forward(self, inputs, input_paddings): - #inputs: NHD - #padding: NH + # inputs: NHD + # padding: NH """ Alternatively: inputs[input_paddings==0] = F.batch_norm( @@ -370,9 +372,11 @@ def forward(self, inputs, input_paddings): var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count self.running_mean = (1 - self.momentum) * self.running_mean + ( - self.momentum) * mean.detach() + self.momentum + ) * mean.detach() self.running_var = (1 - self.momentum) * self.running_var + ( - self.momentum) * var.detach() + self.momentum + ) * var.detach() else: mean = self.running_mean @@ -384,34 +388,31 @@ def forward(self, inputs, input_paddings): class ConvolutionBlock(nn.Module): - def __init__(self, config): super().__init__() self.config = config self.ln = LayerNorm(dim=config.encoder_dim) self.lin1 = nn.Linear( - in_features=config.encoder_dim, out_features=config.encoder_dim) + in_features=config.encoder_dim, out_features=config.encoder_dim + ) self.lin2 = nn.Linear( - in_features=config.encoder_dim, out_features=config.encoder_dim) + in_features=config.encoder_dim, out_features=config.encoder_dim + ) self.conv1 = nn.Conv1d( - in_channels=config.encoder_dim, - out_channels=config.encoder_dim, - kernel_size=(config.convolution_kernel_size,), - stride=(1,), - padding='same', - bias=False, - groups=config.encoder_dim) + in_channels=config.encoder_dim, + out_channels=config.encoder_dim, + kernel_size=(config.convolution_kernel_size,), + stride=(1,), + 
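The attn_mask handed to F.scaled_dot_product_attention in MHSAwithQS is a boolean mask derived from the padding vector: True where attention is allowed, broadcast over heads and query positions. A minimal stand-alone call:

import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 1, 2, 4, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

paddings = torch.tensor([[0, 0, 0, 1]])  # last position is padding
key_padding_mask = paddings == 1  # True at padded keys
attn_mask = ~key_padding_mask[:, None, None]  # broadcasts to [B, H, L, S]

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 2, 4, 8])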
padding='same', + bias=False, + groups=config.encoder_dim, + ) self.bn = BatchNorm(config) self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim) - if config.conv_residual_dropout_rate is None: - conv_residual_dropout_rate = 0.0 - else: - conv_residual_dropout_rate = config.conv_residual_dropout_rate - self.dropout = nn.Dropout(p=conv_residual_dropout_rate, inplace=True) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate): inputs = self.ln(inputs) inputs = F.glu(torch.cat([self.lin1(inputs), self.lin2(inputs)], dim=2)) @@ -427,18 +428,21 @@ def forward(self, inputs, input_paddings): elif self.config.activation_function_name == 'gelu': activation_fn = F.gelu else: - raise ValueError('Only "swish" and "gelu" are supported ' - 'config.activation_function_name values, recieved ' - f'{self.config.activation_function_name}') + raise ValueError( + 'Only "swish" and "gelu" are supported ' + 'config.activation_function_name values, recieved ' + f'{self.config.activation_function_name}' + ) inputs = activation_fn(inputs) inputs = self.lin3(inputs) - inputs = self.dropout(inputs) + inputs = F.dropout( + inputs, dropout_rate, training=self.training, inplace=True + ) return inputs class ConformerBlock(nn.Module): - def __init__(self, config: ConformerConfig): super().__init__() @@ -450,59 +454,57 @@ def __init__(self, config: ConformerConfig): if config.use_post_layer_norm: self.ln = LayerNorm(dim=config.encoder_dim) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate): padding_mask = 1 - input_paddings[:, :, None] - inputs = inputs + 0.5 * self.ff1(inputs, padding_mask) - inputs = inputs + self.mhsa(inputs, input_paddings) - inputs = inputs + self.conv(inputs, input_paddings) - inputs = inputs + 0.5 * self.ff2(inputs, padding_mask) + inputs = inputs + 0.5 * self.ff1(inputs, padding_mask, dropout_rate) + inputs = inputs + self.mhsa(inputs, input_paddings, dropout_rate) + inputs = inputs + self.conv(inputs, input_paddings, dropout_rate) + inputs = inputs + 0.5 * self.ff2(inputs, padding_mask, dropout_rate) if self.ln: inputs = self.ln(inputs) return inputs class ConformerEncoderDecoder(nn.Module): - def __init__(self, config: ConformerConfig): super().__init__() self.config = config preprocessing_config = preprocessor.PreprocessorConfig() self.preprocessor = preprocessor.MelFilterbankFrontend( - preprocessing_config, - per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, - per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR) + preprocessing_config, + per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, + per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR, + ) self.specaug = SpecAug( - freq_mask_count=config.freq_mask_count, - freq_mask_max_bins=config.freq_mask_max_bins, - time_mask_count=config.time_mask_count, - time_mask_max_frames=config.time_mask_max_frames, - time_mask_max_ratio=config.time_mask_max_ratio, - time_masks_per_frame=config.time_masks_per_frame, - use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames + freq_mask_count=config.freq_mask_count, + freq_mask_max_bins=config.freq_mask_max_bins, + time_mask_count=config.time_mask_count, + time_mask_max_frames=config.time_mask_max_frames, + time_mask_max_ratio=config.time_mask_max_ratio, + time_masks_per_frame=config.time_masks_per_frame, + use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames, ) - if config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = 
config.input_dropout_rate self.subsample = Subsample( - encoder_dim=config.encoder_dim, - input_dropout_rate=input_dropout_rate, - num_bins=preprocessing_config.num_bins) + encoder_dim=config.encoder_dim, num_bins=preprocessing_config.num_bins + ) self.conformers = nn.ModuleList( - [ConformerBlock(config) for _ in range(config.num_encoder_layers)]) + [ConformerBlock(config) for _ in range(config.num_encoder_layers)] + ) self.ln = LayerNorm(config.encoder_dim) self.lin = nn.Linear(config.encoder_dim, config.vocab_size) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): outputs = inputs output_paddings = input_paddings outputs, output_paddings = self.preprocessor(outputs, output_paddings) if self.training and self.config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) - outputs, output_paddings = self.subsample(outputs, output_paddings) + outputs, output_paddings = self.subsample( + outputs, output_paddings, dropout_rate + ) for conformer in self.conformers: - outputs = conformer(outputs, output_paddings) + outputs = conformer(outputs, output_paddings, dropout_rate) outputs = self.ln(outputs) outputs = self.lin(outputs) return outputs, output_paddings diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py index 558a0f796..58dd837dc 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py @@ -2,188 +2,189 @@ https://github.com/google/init2winit/blob/master/init2winit/model_lib/librispeech_preprocessor.py. """ -from dataclasses import dataclass import math +from dataclasses import dataclass from typing import Any, Optional, Union import numpy as np import torch -from torch import nn import torch.nn.functional as F +from torch import nn # mel spectrum constants. 
_MEL_BREAK_FREQUENCY_HERTZ = 700.0 _MEL_HIGH_FREQUENCY_Q = 1127.0 LIBRISPEECH_MEAN_VECTOR = [ - -7.6047816276550293, - -7.1206226348876953, - -6.8864245414733887, - -6.8705768585205078, - -6.9667720794677734, - -7.1084094047546387, - -6.9528026580810547, - -6.783994197845459, - -6.6195521354675293, - -6.4876265525817871, - -6.4120659828186035, - -6.394047737121582, - -6.4244871139526367, - -6.3993711471557617, - -6.5158271789550781, - -6.7137999534606934, - -6.8476877212524414, - -6.9885001182556152, - -6.9221386909484863, - -7.146148681640625, - -7.2040400505065918, - -7.0537552833557129, - -7.3140382766723633, - -7.1223249435424805, - -7.30251407623291, - -7.1212143898010254, - -7.2425732612609863, - -7.1730537414550781, - -7.0979413986206055, - -7.088747501373291, - -6.9849910736083984, - -6.8787732124328613, - -6.7602753639221191, - -6.6300945281982422, - -6.5145769119262695, - -6.4245057106018066, - -6.356513500213623, - -6.31787633895874, - -6.2660770416259766, - -6.2468328475952148, - -6.2821526527404785, - -6.1908388137817383, - -6.2484354972839355, - -6.1472640037536621, - -6.0924725532531738, - -6.0171003341674805, - -5.9250402450561523, - -5.8535833358764648, - -5.8209109306335449, - -5.8118929862976074, - -5.80783748626709, - -5.7714629173278809, - -5.7453732490539551, - -5.7705655097961426, - -5.7765641212463379, - -5.7831673622131348, - -5.7954087257385254, - -5.7994823455810547, - -5.8023476600646973, - -5.8047118186950684, - -5.8168182373046875, - -5.8844799995422363, - -5.9727106094360352, - -6.0444660186767578, - -6.1284866333007812, - -6.2257585525512695, - -6.3157496452331543, - -6.39061164855957, - -6.4928598403930664, - -6.5498456954956055, - -6.6054320335388184, - -6.6508378982543945, - -6.66917610168457, - -6.6726889610290527, - -6.684234619140625, - -6.6974577903747559, - -6.75471830368042, - -6.7949142456054688, - -6.8634209632873535, - -6.94186544418335 + -7.6047816276550293, + -7.1206226348876953, + -6.8864245414733887, + -6.8705768585205078, + -6.9667720794677734, + -7.1084094047546387, + -6.9528026580810547, + -6.783994197845459, + -6.6195521354675293, + -6.4876265525817871, + -6.4120659828186035, + -6.394047737121582, + -6.4244871139526367, + -6.3993711471557617, + -6.5158271789550781, + -6.7137999534606934, + -6.8476877212524414, + -6.9885001182556152, + -6.9221386909484863, + -7.146148681640625, + -7.2040400505065918, + -7.0537552833557129, + -7.3140382766723633, + -7.1223249435424805, + -7.30251407623291, + -7.1212143898010254, + -7.2425732612609863, + -7.1730537414550781, + -7.0979413986206055, + -7.088747501373291, + -6.9849910736083984, + -6.8787732124328613, + -6.7602753639221191, + -6.6300945281982422, + -6.5145769119262695, + -6.4245057106018066, + -6.356513500213623, + -6.31787633895874, + -6.2660770416259766, + -6.2468328475952148, + -6.2821526527404785, + -6.1908388137817383, + -6.2484354972839355, + -6.1472640037536621, + -6.0924725532531738, + -6.0171003341674805, + -5.9250402450561523, + -5.8535833358764648, + -5.8209109306335449, + -5.8118929862976074, + -5.80783748626709, + -5.7714629173278809, + -5.7453732490539551, + -5.7705655097961426, + -5.7765641212463379, + -5.7831673622131348, + -5.7954087257385254, + -5.7994823455810547, + -5.8023476600646973, + -5.8047118186950684, + -5.8168182373046875, + -5.8844799995422363, + -5.9727106094360352, + -6.0444660186767578, + -6.1284866333007812, + -6.2257585525512695, + -6.3157496452331543, + -6.39061164855957, + -6.4928598403930664, + -6.5498456954956055, + -6.6054320335388184, + 
-6.6508378982543945, + -6.66917610168457, + -6.6726889610290527, + -6.684234619140625, + -6.6974577903747559, + -6.75471830368042, + -6.7949142456054688, + -6.8634209632873535, + -6.94186544418335, ] LIBRISPEECH_STD_VECTOR = [ - 3.4353282451629639, - 3.5962932109832764, - 3.7012472152709961, - 3.7369205951690674, - 3.7535104751586914, - 3.693629264831543, - 3.6922497749328613, - 3.7641522884368896, - 3.8419716358184814, - 3.8999848365783691, - 3.9294240474700928, - 3.9317409992218018, - 3.9139585494995117, - 3.9031598567962646, - 3.8691999912261963, - 3.8155081272125244, - 3.7644970417022705, - 3.7099106311798096, - 3.6965086460113525, - 3.6003766059875488, - 3.5493226051330566, - 3.5465121269226074, - 3.45003604888916, - 3.4712812900543213, - 3.4084610939025879, - 3.4408135414123535, - 3.4104881286621094, - 3.4217638969421387, - 3.4312851428985596, - 3.4199209213256836, - 3.4305806159973145, - 3.4382665157318115, - 3.4580366611480713, - 3.4817991256713867, - 3.4958710670471191, - 3.5036792755126953, - 3.5047574043273926, - 3.4988734722137451, - 3.493056058883667, - 3.4822943210601807, - 3.459430456161499, - 3.4612770080566406, - 3.4559063911437988, - 3.4755423069000244, - 3.4971549510955811, - 3.5326557159423828, - 3.5705199241638184, - 3.5920312404632568, - 3.596907377243042, - 3.5913500785827637, - 3.5865931510925293, - 3.5826809406280518, - 3.5837743282318115, - 3.5895791053771973, - 3.5819313526153564, - 3.5837869644165039, - 3.5861184597015381, - 3.5889589786529541, - 3.592214822769165, - 3.5939455032348633, - 3.5856630802154541, - 3.5884113311767578, - 3.5921022891998291, - 3.5870490074157715, - 3.5806570053100586, - 3.5731067657470703, - 3.5617532730102539, - 3.54980731010437, - 3.5527374744415283, - 3.5475366115570068, - 3.5387849807739258, - 3.5256178379058838, - 3.5031836032867432, - 3.4922726154327393, - 3.4879646301269531, - 3.4725594520568848, - 3.4558389186859131, - 3.4351828098297119, - 3.4284293651580811, - 3.4299170970916748 + 3.4353282451629639, + 3.5962932109832764, + 3.7012472152709961, + 3.7369205951690674, + 3.7535104751586914, + 3.693629264831543, + 3.6922497749328613, + 3.7641522884368896, + 3.8419716358184814, + 3.8999848365783691, + 3.9294240474700928, + 3.9317409992218018, + 3.9139585494995117, + 3.9031598567962646, + 3.8691999912261963, + 3.8155081272125244, + 3.7644970417022705, + 3.7099106311798096, + 3.6965086460113525, + 3.6003766059875488, + 3.5493226051330566, + 3.5465121269226074, + 3.45003604888916, + 3.4712812900543213, + 3.4084610939025879, + 3.4408135414123535, + 3.4104881286621094, + 3.4217638969421387, + 3.4312851428985596, + 3.4199209213256836, + 3.4305806159973145, + 3.4382665157318115, + 3.4580366611480713, + 3.4817991256713867, + 3.4958710670471191, + 3.5036792755126953, + 3.5047574043273926, + 3.4988734722137451, + 3.493056058883667, + 3.4822943210601807, + 3.459430456161499, + 3.4612770080566406, + 3.4559063911437988, + 3.4755423069000244, + 3.4971549510955811, + 3.5326557159423828, + 3.5705199241638184, + 3.5920312404632568, + 3.596907377243042, + 3.5913500785827637, + 3.5865931510925293, + 3.5826809406280518, + 3.5837743282318115, + 3.5895791053771973, + 3.5819313526153564, + 3.5837869644165039, + 3.5861184597015381, + 3.5889589786529541, + 3.592214822769165, + 3.5939455032348633, + 3.5856630802154541, + 3.5884113311767578, + 3.5921022891998291, + 3.5870490074157715, + 3.5806570053100586, + 3.5731067657470703, + 3.5617532730102539, + 3.54980731010437, + 3.5527374744415283, + 3.5475366115570068, + 3.5387849807739258, + 3.5256178379058838, + 
3.5031836032867432, + 3.4922726154327393, + 3.4879646301269531, + 3.4725594520568848, + 3.4558389186859131, + 3.4351828098297119, + 3.4284293651580811, + 3.4299170970916748, ] @dataclass class PreprocessorConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + sample_rate = 16000 frame_size_ms = 25 frame_step_ms = 10 @@ -203,10 +204,12 @@ class PreprocessorConfig: def _hertz_to_mel(frequencies_hertz): """Convert hertz to mel.""" - log_fn = math.log if type(frequencies_hertz) in [type(0.0), type(0) - ] else torch.log - return _MEL_HIGH_FREQUENCY_Q * log_fn(1.0 + (frequencies_hertz / - _MEL_BREAK_FREQUENCY_HERTZ)) + log_fn = ( + math.log if type(frequencies_hertz) in [type(0.0), type(0)] else torch.log + ) + return _MEL_HIGH_FREQUENCY_Q * log_fn( + 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ) + ) def _pad_end_length(num_timesteps, frame_step, frame_size): @@ -218,28 +221,30 @@ def _pad_end_length(num_timesteps, frame_step, frame_size): return padded_length - num_timesteps -def frame(x, - frame_length: int, - frame_step: int, - pad_end: bool = False, - pad_value: Union[int, float] = 0.0): +def frame( + x, + frame_length: int, + frame_step: int, + pad_end: bool = False, + pad_value: Union[int, float] = 0.0, +): """Slides a window and extract values. - This function extracts `x[:, n:n+frame_length, :]` with sliding `n` with - stride of `frame_step`, and returns an array `y` with the shape - `(batch_size, num_frames, frame_length, num_channels)`. Unlike the - counterpart in Tensorflow (`tf.signal.frame`), this function currently - does not take `axis` argument, and the input tensor `x` is expected to - have a shape of `(batch_size, timesteps, channels)`. - Args: - x: An input array with `(batch_size, timesteps, channels)`-shape. - frame_length: The frame length. - frame_step: The frame hop size. - pad_end: If True, the end of signal is padded so the window can continue - sliding while the starting point of the window is in the valid range. - pad_value: A scalar used as a padding value when `pad_end` is True. - Returns: - A tensor with shape `(*, num_frames, frame_length, num_channels)`. - """ + This function extracts `x[:, n:n+frame_length, :]` with sliding `n` with + stride of `frame_step`, and returns an array `y` with the shape + `(batch_size, num_frames, frame_length, num_channels)`. Unlike the + counterpart in Tensorflow (`tf.signal.frame`), this function currently + does not take `axis` argument, and the input tensor `x` is expected to + have a shape of `(batch_size, timesteps, channels)`. + Args: + x: An input array with `(batch_size, timesteps, channels)`-shape. + frame_length: The frame length. + frame_step: The frame hop size. + pad_end: If True, the end of signal is padded so the window can continue + sliding while the starting point of the window is in the valid range. + pad_value: A scalar used as a padding value when `pad_end` is True. + Returns: + A tensor with shape `(*, num_frames, frame_length, num_channels)`. 
+ """ num_timesteps = x.shape[1] if pad_end: @@ -250,60 +255,67 @@ def frame(x, return x.permute(0, 1, 3, 2) -def linear_to_mel_weight_matrix(num_mel_bins: int = 20, - num_spectrogram_bins: int = 129, - sample_rate: Union[int, float] = 8000, - lower_edge_hertz: Union[int, float] = 125.0, - upper_edge_hertz: Union[int, float] = 3800.0, - dtype: Any = torch.float32, - device='cpu'): +def linear_to_mel_weight_matrix( + num_mel_bins: int = 20, + num_spectrogram_bins: int = 129, + sample_rate: Union[int, float] = 8000, + lower_edge_hertz: Union[int, float] = 125.0, + upper_edge_hertz: Union[int, float] = 3800.0, + dtype: Any = torch.float32, + device='cpu', +): r"""Pytorch-port of `tf.signal.linear_to_mel_weight_matrix`. - Args: - num_mel_bins: Python int. How many bands in the resulting mel spectrum. - num_spectrogram_bins: An integer `Tensor`. How many bins there are in - the source spectrogram data, which is understood to be `fft_size // 2 + 1`, - i.e. the spectrogram only contains the nonredundant FFT bins. - sample_rate: An integer or float `Tensor`. Samples per second of the - input signal used to create the spectrogram. Used to figure out the - frequencies corresponding to each spectrogram bin, which dictates how they - are mapped into the mel scale. - lower_edge_hertz: Python float. Lower bound on the frequencies to be - included in the mel spectrum. This corresponds to the lower edge of the - lowest triangular band. - upper_edge_hertz: Python float. The desired top edge of the highest - frequency band. - dtype: The `DType` of the result matrix. Must be a floating point type. - Returns: - An array of shape `[num_spectrogram_bins, num_mel_bins]`. - Raises: - ValueError: If `num_mel_bins`/`num_spectrogram_bins`/`sample_rate` are not - positive, `lower_edge_hertz` is negative, frequency edges are incorrectly - ordered, `upper_edge_hertz` is larger than the Nyquist frequency. - [mel]: https://en.wikipedia.org/wiki/Mel_scale - """ + Args: + num_mel_bins: Python int. How many bands in the resulting mel spectrum. + num_spectrogram_bins: An integer `Tensor`. How many bins there are in + the source spectrogram data, which is understood to be `fft_size // 2 + 1`, + i.e. the spectrogram only contains the nonredundant FFT bins. + sample_rate: An integer or float `Tensor`. Samples per second of the + input signal used to create the spectrogram. Used to figure out the + frequencies corresponding to each spectrogram bin, which dictates how they + are mapped into the mel scale. + lower_edge_hertz: Python float. Lower bound on the frequencies to be + included in the mel spectrum. This corresponds to the lower edge of the + lowest triangular band. + upper_edge_hertz: Python float. The desired top edge of the highest + frequency band. + dtype: The `DType` of the result matrix. Must be a floating point type. + Returns: + An array of shape `[num_spectrogram_bins, num_mel_bins]`. + Raises: + ValueError: If `num_mel_bins`/`num_spectrogram_bins`/`sample_rate` are not + positive, `lower_edge_hertz` is negative, frequency edges are incorrectly + ordered, `upper_edge_hertz` is larger than the Nyquist frequency. + [mel]: https://en.wikipedia.org/wiki/Mel_scale + """ # Input validator from tensorflow/python/ops/signal/mel_ops.py#L71 if num_mel_bins <= 0: raise ValueError('num_mel_bins must be positive. Got: %s' % num_mel_bins) if lower_edge_hertz < 0.0: - raise ValueError('lower_edge_hertz must be non-negative. Got: %s' % - lower_edge_hertz) + raise ValueError( + 'lower_edge_hertz must be non-negative. 
Got: %s' % lower_edge_hertz + ) if lower_edge_hertz >= upper_edge_hertz: - raise ValueError('lower_edge_hertz %.1f >= upper_edge_hertz %.1f' % - (lower_edge_hertz, upper_edge_hertz)) + raise ValueError( + 'lower_edge_hertz %.1f >= upper_edge_hertz %.1f' + % (lower_edge_hertz, upper_edge_hertz) + ) if sample_rate <= 0.0: raise ValueError('sample_rate must be positive. Got: %s' % sample_rate) if upper_edge_hertz > sample_rate / 2: - raise ValueError('upper_edge_hertz must not be larger than the Nyquist ' - 'frequency (sample_rate / 2). Got %s for sample_rate: %s' % - (upper_edge_hertz, sample_rate)) + raise ValueError( + 'upper_edge_hertz must not be larger than the Nyquist ' + 'frequency (sample_rate / 2). Got %s for sample_rate: %s' + % (upper_edge_hertz, sample_rate) + ) # HTK excludes the spectrogram DC bin. bands_to_zero = 1 nyquist_hertz = sample_rate / 2.0 linear_frequencies = torch.linspace( - 0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype, - device=device)[bands_to_zero:] + 0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype, device=device + )[bands_to_zero:] spectrogram_bins_mel = _hertz_to_mel(linear_frequencies)[:, None] # Compute num_mel_bins triples of (lower_edge, center, upper_edge). The @@ -311,11 +323,12 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20, # Accordingly, we divide [lower_edge_hertz, upper_edge_hertz] into # num_mel_bins + 2 pieces. edges = torch.linspace( - _hertz_to_mel(lower_edge_hertz), - _hertz_to_mel(upper_edge_hertz), - num_mel_bins + 2, - dtype=dtype, - device=device) + _hertz_to_mel(lower_edge_hertz), + _hertz_to_mel(upper_edge_hertz), + num_mel_bins + 2, + dtype=dtype, + device=device, + ) # Split the triples up and reshape them into [1, num_mel_bins] tensors. lower_edge_mel = edges[:-2][None, :] @@ -325,13 +338,16 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20, # Calculate lower and upper slopes for every spectrogram bin. # Line segments are linear in the mel domain, not Hertz. lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / ( - center_mel - lower_edge_mel) + center_mel - lower_edge_mel + ) upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / ( - upper_edge_mel - center_mel) + upper_edge_mel - center_mel + ) # Intersect the line segments with each other and zero. mel_weights_matrix = torch.minimum(lower_slopes, upper_slopes).clamp( - min=0.0, max=None) + min=0.0, max=None + ) # Re-add the zeroed lower bins we sliced out above. return F.pad(mel_weights_matrix, (0, 0, bands_to_zero, 0)) @@ -339,43 +355,50 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20, def _hanning_greco(win_support, frame_size, dtype, device='cpu'): """Add a greco-style hanning window to the graph. - Note that the Hanning window in Wikipedia is not the same as the Hanning - window in Greco. The Greco3 Hanning window at 0 is NOT 0, as the wikipedia - page would indicate. Talkin's explanation was that it was like wasting two - samples to have the values at the edge of the window to be 0.0 exactly. - Args: - win_support: Number of samples for non-zero support in the window - frame_size: Total size of the window (frame_size >= win_support) - dtype: TF data type - Returns: - Tensor of size frame_size with the window to apply. - """ + Note that the Hanning window in Wikipedia is not the same as the Hanning + window in Greco. The Greco3 Hanning window at 0 is NOT 0, as the wikipedia + page would indicate. Talkin's explanation was that it was like wasting two + samples to have the values at the edge of the window to be 0.0 exactly. 
+ Args: + win_support: Number of samples for non-zero support in the window + frame_size: Total size of the window (frame_size >= win_support) + dtype: TF data type + Returns: + Tensor of size frame_size with the window to apply. + """ if frame_size < win_support: raise ValueError( - 'Provided frame_size = {} is lower than win_support = {}'.format( - frame_size, win_support)) + 'Provided frame_size = {} is lower than win_support = {}'.format( + frame_size, win_support + ) + ) arg = torch.pi * 2.0 / (win_support) - hann = 0.5 - (0.5 * torch.cos( - arg * (torch.arange(win_support, dtype=dtype, device=device) + 0.5))) + hann = 0.5 - ( + 0.5 + * torch.cos( + arg * (torch.arange(win_support, dtype=dtype, device=device) + 0.5) + ) + ) zero_size = frame_size - win_support return F.pad(hann, (0, zero_size)) def _next_pow_of_two(x: Union[int, float]) -> int: - return int(2**np.ceil(np.log2(x))) + return int(2 ** np.ceil(np.log2(x))) class SpectrogramFrontend(nn.Module): - """Layer to convert input audio signals from time domain to frequency domain. - """ - - def __init__(self, - config: PreprocessorConfig = None, - input_scale_factor: float = 1.0, - output_log: bool = False, - dtype=torch.float32, - device='cpu'): + """Layer to convert input audio signals from time domain to frequency domain.""" + + def __init__( + self, + config: PreprocessorConfig = None, + input_scale_factor: float = 1.0, + output_log: bool = False, + dtype=torch.float32, + device='cpu', + ): super().__init__() self.config = config @@ -384,8 +407,9 @@ def __init__(self, p = self.config self._frame_step = int(round(p.sample_rate * p.frame_step_ms / 1000.0)) - self._frame_size = int(round( - p.sample_rate * p.frame_size_ms / 1000.0)) + 1 # +1 for the preemph + self._frame_size = ( + int(round(p.sample_rate * p.frame_size_ms / 1000.0)) + 1 + ) # +1 for the preemph # TF-version has maximum of 512, but it's not always necessary self.fft_size = _next_pow_of_two(self._frame_size) @@ -399,23 +423,20 @@ def _hanning_window(frame_size, dtype): if frame_size % 2 == 0: # simulate periodic=True in tf.signal.hann_window return torch.hann_window( - window_length=frame_size, - periodic=True, - dtype=dtype, - device=device) + window_length=frame_size, periodic=True, dtype=dtype, device=device + ) else: return torch.hann_window( - window_length=frame_size, - periodic=False, - dtype=dtype, - device=device) + window_length=frame_size, periodic=False, dtype=dtype, device=device + ) self._window_fn = _hanning_window elif p.window_fn.upper() == 'HANNING_GRECO': # Greco-compatible hanning window def f(frame_size, dtype): return _hanning_greco( - self._frame_size - 1, frame_size, dtype, device=device) + self._frame_size - 1, frame_size, dtype, device=device + ) self._window_fn = f else: @@ -430,25 +451,31 @@ def f(frame_size, dtype): def _apply_preemphasis(self, framed_signal): p = self.config if p.preemph_htk_flavor: - return torch.cat([ - framed_signal[:, :, :1, :] * (1. 
- p.preemph), - (framed_signal[:, :, 1:-1, :] - - p.preemph * framed_signal[:, :, :-2, :]) - ], - dim=2) + return torch.cat( + [ + framed_signal[:, :, :1, :] * (1.0 - p.preemph), + ( + framed_signal[:, :, 1:-1, :] + - p.preemph * framed_signal[:, :, :-2, :] + ), + ], + dim=2, + ) else: - return (framed_signal[:, :, 1:, :] - - p.preemph * framed_signal[:, :, :-1, :]) + return ( + framed_signal[:, :, 1:, :] - p.preemph * framed_signal[:, :, :-1, :] + ) def fprop_paddings(self, input_paddings): p = self.config if p.pad_end: - num_extends = _pad_end_length(input_paddings.shape[1], - self._frame_step, - self._frame_size) + num_extends = _pad_end_length( + input_paddings.shape[1], self._frame_step, self._frame_size + ) input_paddings = F.pad(input_paddings, (0, num_extends), value=1.0) x = input_paddings.unfold( - dimension=1, size=self._frame_size, step=self._frame_step) + dimension=1, size=self._frame_size, step=self._frame_step + ) return x.min(dim=2)[0] def forward(self, inputs, input_paddings): @@ -467,7 +494,8 @@ def forward(self, inputs, input_paddings): pcm_audio_chunk = inputs * self.input_scale_factor framed_signal = frame( - pcm_audio_chunk, self._frame_size, self._frame_step, pad_end=p.pad_end) + pcm_audio_chunk, self._frame_size, self._frame_step, pad_end=p.pad_end + ) if p.preemph != 0.0: preemphasized = self._apply_preemphasis(framed_signal) @@ -497,12 +525,14 @@ def forward(self, inputs, input_paddings): class MelFilterbankFrontend(nn.Module): """Layer to compute log mel spectograms from input audio signals.""" - def __init__(self, - config: PreprocessorConfig = None, - use_divide_stream: bool = True, - per_bin_mean: Optional[float] = None, - per_bin_stddev: Optional[float] = None, - device='cpu'): + def __init__( + self, + config: PreprocessorConfig = None, + use_divide_stream: bool = True, + per_bin_mean: Optional[float] = None, + per_bin_stddev: Optional[float] = None, + device='cpu', + ): super().__init__() self.config = config @@ -513,7 +543,8 @@ def __init__(self, input_scale_factor = 2**-15 if self.use_divide_stream else 1.0 self.stft = SpectrogramFrontend( - p, input_scale_factor=input_scale_factor, output_log=False) + p, input_scale_factor=input_scale_factor, output_log=False + ) if self.per_bin_mean is None: per_bin_mean = [0.0] * p.num_bins @@ -525,10 +556,13 @@ def __init__(self, else: per_bin_stddev = self.per_bin_stddev - self.register_buffer('_normalizer_mean', - torch.FloatTensor(per_bin_mean)[None, None, :, None]) - self.register_buffer('_normalizer_stddev', - torch.FloatTensor(per_bin_stddev)[None, None, :, None]) + self.register_buffer( + '_normalizer_mean', torch.FloatTensor(per_bin_mean)[None, None, :, None] + ) + self.register_buffer( + '_normalizer_stddev', + torch.FloatTensor(per_bin_stddev)[None, None, :, None], + ) def forward(self, inputs, input_paddings): p = self.config @@ -536,20 +570,24 @@ def forward(self, inputs, input_paddings): spect, spect_paddings = self.stft(inputs, input_paddings) mel_weights = linear_to_mel_weight_matrix( - num_mel_bins=p.num_bins, - num_spectrogram_bins=spect.shape[2], - sample_rate=p.sample_rate, - lower_edge_hertz=p.lower_edge_hertz, - upper_edge_hertz=p.upper_edge_hertz, - device=spect.device) + num_mel_bins=p.num_bins, + num_spectrogram_bins=spect.shape[2], + sample_rate=p.sample_rate, + lower_edge_hertz=p.lower_edge_hertz, + upper_edge_hertz=p.upper_edge_hertz, + device=spect.device, + ) mel_spectrogram = torch.einsum('fn,btfc->btnc', mel_weights, spect) logmel_spectrogram = torch.log( - 
mel_spectrogram.clamp(min=p.output_floor, max=None)) + mel_spectrogram.clamp(min=p.output_floor, max=None) + ) normalized_logmel_spectrogram = ( - (logmel_spectrogram - self._normalizer_mean) / self._normalizer_stddev) + logmel_spectrogram - self._normalizer_mean + ) / self._normalizer_stddev - normalized_logmel_spectrogram = torch.squeeze(normalized_logmel_spectrogram, - -1) + normalized_logmel_spectrogram = torch.squeeze( + normalized_logmel_spectrogram, -1 + ) return normalized_logmel_spectrogram, spect_paddings diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py index 11b93703e..66db657b8 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py @@ -9,19 +9,21 @@ class SpecAug(nn.Module): """Layer performs masking prodecure along time and frequency axis. - The procedure is detailed in https://arxiv.org/abs/1904.08779. - This is an essential component in speech recognition models that - helps achieve better word error rates. - """ - - def __init__(self, - freq_mask_count: int = 1, - freq_mask_max_bins: int = 15, - time_mask_count: int = 1, - time_mask_max_frames: int = 50, - time_mask_max_ratio: float = 1.0, - time_masks_per_frame: float = 0.0, - use_dynamic_time_mask_max_frames: bool = False): + The procedure is detailed in https://arxiv.org/abs/1904.08779. + This is an essential component in speech recognition models that + helps achieve better word error rates. + """ + + def __init__( + self, + freq_mask_count: int = 1, + freq_mask_max_bins: int = 15, + time_mask_count: int = 1, + time_mask_max_frames: int = 50, + time_mask_max_ratio: float = 1.0, + time_masks_per_frame: float = 0.0, + use_dynamic_time_mask_max_frames: bool = False, + ): super().__init__() self.freq_mask_count = freq_mask_count @@ -35,23 +37,26 @@ def __init__(self, def next_prng_key(self, name='dropout'): return self.make_rng(name) - def _get_mask(self, - batch_size, - choose_range, - mask_size, - max_length=None, - masks_per_frame=0.0, - multiplicity=1, - max_ratio=1.0, - device='cpu'): + def _get_mask( + self, + batch_size, + choose_range, + mask_size, + max_length=None, + masks_per_frame=0.0, + multiplicity=1, + max_ratio=1.0, + device='cpu', + ): # Sample lengths for multiple masks. if max_length and max_length > 0: max_length = max_length * torch.ones(batch_size, device=device) else: max_length = choose_range * max_ratio masked_portion = torch.rand(batch_size, multiplicity, device=device) - masked_frame_size = torch.einsum('b,bm->bm', max_length, - masked_portion).long() + masked_frame_size = torch.einsum( + 'b,bm->bm', max_length, masked_portion + ).long() # Make sure the sampled length was smaller than max_ratio * length_bound. # Note that sampling in this way was biased # (shorter sequence may over-masked.) @@ -80,8 +85,9 @@ def _get_mask(self, # Sum masks with appropriate multiplicity. 
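# --- Editor's aside (illustration only, not part of the patch) ---------------
# _get_mask above samples SpecAug masks for a whole batch and for several masks
# ("multiplicity") at once. For intuition, the same procedure for a single time
# mask on a single [time, freq] utterance looks roughly like this (simplified,
# illustrative helper, not part of the repository):
import torch

def apply_one_time_mask(x: torch.Tensor, max_frames: int) -> torch.Tensor:
  """Zeroes one random contiguous span of at most `max_frames` time steps."""
  num_frames = x.shape[0]
  mask_len = int(torch.randint(0, max_frames + 1, ()).item())
  start_max = max(num_frames - mask_len, 0)
  start = int(torch.randint(0, start_max + 1, ()).item())
  out = x.clone()
  out[start:start + mask_len, :] = 0.0
  return out

# Example: mask at most 50 frames of a (100 frames x 80 mel bins) spectrogram.
# masked = apply_one_time_mask(torch.randn(100, 80), max_frames=50)
# ------------------------------------------------------------------------------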
if masks_per_frame > 0: multiplicity_weights = torch.tile( - torch.arange(multiplicity, device=device).long()[None, ...], - [batch_size, 1]) + torch.arange(multiplicity, device=device).long()[None, ...], + [batch_size, 1], + ) multiplicity_tensor = masks_per_frame * choose_range multiplicity_weights = (multiplicity_weights < multiplicity_tensor).long() pre_mask = torch.einsum('bmt,bm->bt', pre_mask, multiplicity_weights) @@ -99,8 +105,9 @@ def _time_mask(self, inputs, length): max_ratio = self.time_mask_max_ratio # If maximum mask length is zero, do nothing. - if ((time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames) or - max_ratio <= 0.0): + if ( + time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames + ) or max_ratio <= 0.0: return inputs if multiplicity == 0: return inputs @@ -112,14 +119,15 @@ def _time_mask(self, inputs, length): time_mask_max_frames = None # Create masks in time direction and apply. block_arrays = self._get_mask( - batch_size, - choose_range=length, - mask_size=time_length, - max_length=time_mask_max_frames, - masks_per_frame=self.time_masks_per_frame, - multiplicity=multiplicity, - max_ratio=max_ratio, - device=inputs.device) + batch_size, + choose_range=length, + mask_size=time_length, + max_length=time_mask_max_frames, + masks_per_frame=self.time_masks_per_frame, + multiplicity=multiplicity, + max_ratio=max_ratio, + device=inputs.device, + ) outputs = torch.einsum('bxy,bx->bxy', inputs, block_arrays) return outputs @@ -138,14 +146,15 @@ def _frequency_mask(self, inputs): choose_range = num_freq * torch.ones(batch_size, device=inputs.device) # Create masks in frequency direction and apply. block_arrays = self._get_mask( - batch_size, - choose_range=choose_range, - mask_size=num_freq, - max_length=freq_mask_max_bins, - masks_per_frame=0.0, - multiplicity=multiplicity, - max_ratio=1.0, - device=inputs.device) + batch_size, + choose_range=choose_range, + mask_size=num_freq, + max_length=freq_mask_max_bins, + masks_per_frame=0.0, + multiplicity=multiplicity, + max_ratio=1.0, + device=inputs.device, + ) outputs = torch.einsum('bxy,by->bxy', inputs, block_arrays) return outputs diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 5ed37957e..25416682c 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -10,15 +10,12 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import data_utils -from algoperf import param_utils -from algoperf import pytorch_utils -from algoperf import spec import algoperf.random_utils as prng -from algoperf.workloads.librispeech_conformer import metrics -from algoperf.workloads.librispeech_conformer import workload -from algoperf.workloads.librispeech_conformer.input_pipeline import \ - LibriSpeechDataset +from algoperf import data_utils, param_utils, pytorch_utils, spec +from algoperf.workloads.librispeech_conformer import metrics, workload +from algoperf.workloads.librispeech_conformer.input_pipeline import ( + LibriSpeechDataset, +) from algoperf.workloads.librispeech_conformer.librispeech_pytorch import models USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -27,10 +24,9 @@ class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload): - - def __init__(self, - tokenizer_vocab_path: Optional[str] = 
None, - use_specaug: bool = True) -> None: + def __init__( + self, tokenizer_vocab_path: Optional[str] = None, use_specaug: bool = True + ) -> None: super().__init__() self.tokenizer = metrics.load_tokenizer(tokenizer_vocab_path) self.use_specaug = use_specaug @@ -42,7 +38,8 @@ def __init__(self, def eval_num_workers(self) -> int: if self._eval_num_workers is None: raise ValueError( - 'eval_num_workers property must be set before workload is used.') + 'eval_num_workers property must be set before workload is used.' + ) return self._eval_num_workers @eval_num_workers.setter @@ -61,16 +58,8 @@ def use_gelu(self) -> bool: def attention_temperature(self) -> float: return 1.0 - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Conformer model init function. - - Here we use dropout_rate as residual_dropout_rate, and aux_dropout_rate as - input_dropout_rate. - """ + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: + """Conformer model init function.""" torch.random.manual_seed(rng[0]) # Configure torch backends to avoid OOM errors. torch.backends.cudnn.benchmark = False @@ -82,15 +71,13 @@ def init_model_fn( else: activation_function_name = 'swish' model = models.ConformerEncoderDecoder( - models.ConformerConfig( - attention_residual_dropout_rate=dropout_rate, - feed_forward_residual_dropout_rate=dropout_rate, - conv_residual_dropout_rate=dropout_rate, - input_dropout_rate=aux_dropout_rate, - use_specaug=self.use_specaug, - attention_temperature=self.attention_temperature, - use_post_layer_norm=self.use_post_layer_norm, - activation_function_name=activation_function_name)) + models.ConformerConfig( + use_specaug=self.use_specaug, + attention_temperature=self.attention_temperature, + use_post_layer_norm=self.use_post_layer_norm, + activation_function_name=activation_function_name, + ) + ) self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none') models.initialize(model) self._param_shapes = param_utils.pytorch_param_shapes(model) @@ -109,13 +96,15 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['lin.weight', 'lin.bias'] def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng @@ -125,29 +114,33 @@ def model_fn( if mode == spec.ForwardPassMode.TRAIN: model.train() model.apply( - functools.partial( - pytorch_utils.update_batch_norm_fn, - update_batch_norm=update_batch_norm)) + functools.partial( + pytorch_utils.update_batch_norm_fn, + update_batch_norm=update_batch_norm, + ) + ) contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] - logits, logits_paddings = model(inputs.to(DEVICE), - 
input_paddings.to(DEVICE)) + logits, logits_paddings = model( + inputs.to(DEVICE), input_paddings.to(DEVICE), dropout_rate=dropout_rate + ) return (logits, logits_paddings), None def _build_input_queue( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del cache del repeat_final_dataset del num_batches @@ -166,7 +159,7 @@ def _build_input_queue( if split == 'eval_train': indices = list(range(len(ds))) random.Random(int(data_rng[0])).shuffle(indices) - ds = torch.utils.data.Subset(ds, indices[:self.num_eval_train_examples]) + ds = torch.utils.data.Subset(ds, indices[: self.num_eval_train_examples]) sampler = None if USE_PYTORCH_DDP: @@ -177,31 +170,36 @@ def _build_input_queue( if USE_PYTORCH_DDP: if is_train: sampler = torch.utils.data.distributed.DistributedSampler( - ds, num_replicas=N_GPUS, rank=RANK, shuffle=True) + ds, num_replicas=N_GPUS, rank=RANK, shuffle=True + ) else: sampler = data_utils.DistributedEvalSampler( - ds, num_replicas=N_GPUS, rank=RANK, shuffle=False) + ds, num_replicas=N_GPUS, rank=RANK, shuffle=False + ) dataloader = torch.utils.data.DataLoader( - ds, - batch_size=ds_iter_batch_size, - shuffle=not USE_PYTORCH_DDP and is_train, - sampler=sampler, - num_workers=4, - pin_memory=True, - drop_last=is_train) + ds, + batch_size=ds_iter_batch_size, + shuffle=not USE_PYTORCH_DDP and is_train, + sampler=sampler, + num_workers=4, + pin_memory=True, + drop_last=is_train, + ) dataloader = data_utils.cycle( - dataloader, custom_sampler=USE_PYTORCH_DDP, use_mixup=False) + dataloader, custom_sampler=USE_PYTORCH_DDP, use_mixup=False + ) return dataloader # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding) - logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding) - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding) + logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding) + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -215,10 +213,8 @@ def loss_fn( input_lengths = torch.einsum('bh->b', 1 - logit_paddings).long() target_lengths = torch.einsum('bh->b', 1 - target_paddings).long() per_example_losses = self.ctc_loss( - logprobs.permute(1, 0, 2), - targets.long(), - input_lengths, - target_lengths) + logprobs.permute(1, 0, 2), targets.long(), input_lengths, target_lengths + ) # mask_batch is assumed to be shape [batch]. 
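# --- Editor's aside (illustration only, not part of the patch) ---------------
# torch.nn.CTCLoss expects log-probabilities shaped [T, N, C] (hence the
# permute(1, 0, 2) above) plus per-example input and target lengths, which the
# workload derives by summing the inverse padding masks. A minimal sketch with
# made-up shapes:
import torch

batch, max_frames, max_targets, vocab = 2, 50, 10, 1024
log_probs = torch.randn(batch, max_frames, vocab).log_softmax(dim=-1)
logit_paddings = torch.zeros(batch, max_frames)    # 1.0 marks a padded frame
targets = torch.randint(1, vocab, (batch, max_targets))
target_paddings = torch.zeros(batch, max_targets)  # 1.0 marks a padded label

input_lengths = torch.einsum('bh->b', 1 - logit_paddings).long()
target_lengths = torch.einsum('bh->b', 1 - target_paddings).long()

ctc = torch.nn.CTCLoss(blank=0, reduction='none')
per_example_losses = ctc(
  log_probs.permute(1, 0, 2), targets.long(), input_lengths, target_lengths
)  # shape: [batch]
# ------------------------------------------------------------------------------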
if mask_batch is not None: per_example_losses *= mask_batch @@ -229,21 +225,22 @@ def loss_fn( summed_loss = per_example_losses.sum() n_valid_examples = max(n_valid_examples, 1) return { - 'summed': summed_loss, - 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), + 'per_example': per_example_losses, } def greedy_decode( - self, logits: spec.Tensor, - logit_paddings: spec.Tensor) -> Tuple[spec.Tensor, spec.Tensor]: + self, logits: spec.Tensor, logit_paddings: spec.Tensor + ) -> Tuple[spec.Tensor, spec.Tensor]: framewise_tokens = logits.max(dim=-1)[1] framewise_tokens = framewise_tokens * (1 - logit_paddings) # Add sentinel because unique_consecutive will flatten array # and then compute the unique. framewise_tokens = torch.cat( - [framewise_tokens, -torch.ones_like(framewise_tokens[:, 0:1])], dim=1) + [framewise_tokens, -torch.ones_like(framewise_tokens[:, 0:1])], dim=1 + ) _, indices = torch.unique_consecutive(framewise_tokens, return_inverse=True) indices -= indices.min(dim=1, keepdims=True)[0] result = torch.zeros_like(framewise_tokens) @@ -256,11 +253,12 @@ def greedy_decode( # Remove blanks (id = 0). blank_id = 0 fin_result = torch.zeros_like(result) - idxs = torch.arange( - fin_result.numel(), device=result.device).view(*fin_result.shape) - mask = torch.arange( - fin_result.shape[1], device=result.device).view( - 1, -1) < result.count_nonzero(dim=1).view(-1, 1) + idxs = torch.arange(fin_result.numel(), device=result.device).view( + *fin_result.shape + ) + mask = torch.arange(fin_result.shape[1], device=result.device).view( + 1, -1 + ) < result.count_nonzero(dim=1).view(-1, 1) fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id] padding = fin_result == 0 return fin_result, padding @@ -274,29 +272,31 @@ def sync_sd(self, params: spec.ParameterContainer) -> None: sd[k] = sd[k] / N_GPUS params.load_state_dict(sd) - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del global_step data_rng, model_rng = prng.split(rng, 2) if split not in self._eval_iters: # These iterators repeat indefinitely. 
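# --- Editor's aside (illustration only, not part of the patch) ---------------
# greedy_decode above performs standard CTC greedy decoding in a fully batched
# way: argmax over the vocabulary, collapse consecutive repeats (the
# unique_consecutive trick), then drop the blank id 0. A per-sequence version
# of the same idea, for reference (illustrative helper, not from the repo):
import torch

def ctc_greedy_decode_1d(logits: torch.Tensor, blank_id: int = 0) -> list:
  """logits: [time, vocab] for a single, unpadded sequence."""
  tokens = logits.argmax(dim=-1).tolist()
  decoded, prev = [], None
  for tok in tokens:
    if tok != prev and tok != blank_id:  # collapse repeats, then drop blanks
      decoded.append(tok)
    prev = tok
  return decoded

# Example: ctc_greedy_decode_1d(torch.randn(120, 1024)) -> list of token ids.
# ------------------------------------------------------------------------------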
- self._eval_iters[split] = ( - self._build_input_queue( - data_rng, split, data_dir, global_batch_size=global_batch_size)) + self._eval_iters[split] = self._build_input_queue( + data_rng, split, data_dir, global_batch_size=global_batch_size + ) total_metrics = { - 'loss': torch.tensor(0., device=DEVICE), - 'lengths': torch.tensor(0., device=DEVICE), - 'word_errors': torch.tensor(0., device=DEVICE), - 'num_words': torch.tensor(0., device=DEVICE), + 'loss': torch.tensor(0.0, device=DEVICE), + 'lengths': torch.tensor(0.0, device=DEVICE), + 'word_errors': torch.tensor(0.0, device=DEVICE), + 'num_words': torch.tensor(0.0, device=DEVICE), } num_batches = int(math.ceil(num_examples / global_batch_size)) if self.requires_sync_before_eval: @@ -305,48 +305,50 @@ def _eval_model_on_split(self, batch = next(self._eval_iters[split]) (logits, logits_padding), _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - model_rng, - update_batch_norm=False) + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + model_rng, + update_batch_norm=False, + ) decoded, decoded_paddings = self.greedy_decode(logits, logits_padding) targets, target_paddings = batch['targets'] word_errors, num_words = metrics.compute_wer( - decoded=decoded.cpu().numpy(), - decoded_paddings=decoded_paddings.cpu().numpy(), - targets=targets.cpu().numpy(), - target_paddings=target_paddings.cpu().numpy(), - tokenizer=self.tokenizer) + decoded=decoded.cpu().numpy(), + decoded_paddings=decoded_paddings.cpu().numpy(), + targets=targets.cpu().numpy(), + target_paddings=target_paddings.cpu().numpy(), + tokenizer=self.tokenizer, + ) loss = self.loss_fn((targets, target_paddings), (logits, logits_padding)) summed_loss = loss['summed'] lengths = loss['n_valid_examples'] batch_metrics = { - 'loss': summed_loss, - 'lengths': lengths, - 'word_errors': word_errors, - 'num_words': num_words, + 'loss': summed_loss, + 'lengths': lengths, + 'word_errors': word_errors, + 'num_words': num_words, } total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() + k: v + batch_metrics[k] for k, v in total_metrics.items() } if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) return { - 'ctc_loss': - float(total_metrics['loss'].item() / - total_metrics['lengths'].item()), - 'wer': - float(total_metrics['word_errors'].item() / - total_metrics['num_words'].item()), + 'ctc_loss': float( + total_metrics['loss'].item() / total_metrics['lengths'].item() + ), + 'wer': float( + total_metrics['word_errors'].item() / total_metrics['num_words'].item() + ), } class LibriSpeechConformerAttentionTemperatureWorkload( - LibriSpeechConformerWorkload): - + LibriSpeechConformerWorkload +): @property def attention_temperature(self) -> float: return 1.6 @@ -361,7 +363,6 @@ def test_target_value(self) -> float: class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload): - @property def use_post_layer_norm(self) -> bool: return False @@ -376,7 +377,6 @@ def test_target_value(self) -> float: class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload): - @property def use_gelu(self) -> bool: return True diff --git a/algoperf/workloads/librispeech_conformer/metrics.py b/algoperf/workloads/librispeech_conformer/metrics.py index de74cfe1b..7dd6a11dc 100644 --- a/algoperf/workloads/librispeech_conformer/metrics.py +++ b/algoperf/workloads/librispeech_conformer/metrics.py @@ -1,8 +1,8 @@ -from clu import metrics import flax import numpy as np import tensorflow as tf import 
tensorflow_text as tftxt +from clu import metrics gfile = tf.io.gfile @@ -15,17 +15,20 @@ def average_ctc_loss(): @flax.struct.dataclass class _Metric(metrics.Metric): """Applies `fun` and computes the average.""" + total: np.float32 weight: np.float32 @classmethod def from_model_output(cls, loss_dict, **_): return cls( - total=loss_dict['summed'], weight=loss_dict['n_valid_examples']) + total=loss_dict['summed'], weight=loss_dict['n_valid_examples'] + ) def merge(self, other): return type(self)( - total=self.total + other.total, weight=self.weight + other.weight) + total=self.total + other.total, weight=self.weight + other.weight + ) def compute(self): return self.total / self.weight @@ -74,9 +77,10 @@ def edit_distance(source, target): # possibilities and find minimum. else: distance[i][j] = 1 + min( - distance[i][j - 1], # Insert - distance[i - 1][j], # Remove - distance[i - 1][j - 1]) # Replace + distance[i][j - 1], # Insert + distance[i - 1][j], # Remove + distance[i - 1][j - 1], + ) # Replace return distance[num_source_words][num_target_words] @@ -109,17 +113,20 @@ def compute_wer(decoded, decoded_paddings, targets, target_paddings, tokenizer): return word_errors, num_words -def load_tokenizer(model_path: str, - add_bos: bool = False, - add_eos: bool = True, - reverse: bool = False): +def load_tokenizer( + model_path: str, + add_bos: bool = False, + add_eos: bool = True, + reverse: bool = False, +): """Load a tf-text SentencePiece tokenizer from given model filepath.""" if model_path is None: return None with gfile.GFile(model_path, 'rb') as model_fp: sp_model = model_fp.read() sp_tokenizer = tftxt.SentencepieceTokenizer( - model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse) + model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse + ) return sp_tokenizer @@ -128,8 +135,10 @@ def wer(tokenizer_vocab_path): @flax.struct.dataclass class WER( - metrics.CollectingMetric.from_outputs( - ('decoded', 'decoded_paddings', 'targets', 'target_paddings'))): + metrics.CollectingMetric.from_outputs( + ('decoded', 'decoded_paddings', 'targets', 'target_paddings') + ) + ): """Computes the mean average precision for a binary classifier on CPU.""" def compute(self): @@ -144,7 +153,8 @@ def compute(self): values['decoded_paddings'], values['targets'].astype(np.int32), values['target_paddings'], - tokenizer) + tokenizer, + ) return word_errors / num_words @@ -153,4 +163,5 @@ def compute(self): def get_metrics_bundle(tokenizer_vocab_path): return metrics.Collection.create( - ctc_loss=average_ctc_loss(), wer=wer(tokenizer_vocab_path)) + ctc_loss=average_ctc_loss(), wer=wer(tokenizer_vocab_path) + ) diff --git a/algoperf/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py index 94f01dd97..791270719 100644 --- a/algoperf/workloads/librispeech_conformer/workload.py +++ b/algoperf/workloads/librispeech_conformer/workload.py @@ -5,7 +5,6 @@ class BaseLibrispeechWorkload(spec.Workload): - _num_outputs: int = 1024 @property @@ -25,8 +24,9 @@ def use_gelu(self) -> bool: def attention_temperature(self) -> float: raise NotImplementedError - def has_reached_validation_target(self, eval_result: Dict[str, - float]) -> bool: + def has_reached_validation_target( + self, eval_result: Dict[str, float] + ) -> bool: return eval_result['validation/wer'] < self.validation_target_value @property @@ -53,8 +53,9 @@ def num_eval_train_examples(self) -> int: # Round up from num_validation_examples (which is the default for # num_eval_train_examples) to the 
next multiple of eval_batch_size, so that # we don't have to extract the correctly sized subset of the training data. - rounded_up_multiple = math.ceil(self.num_validation_examples / - self.eval_batch_size) + rounded_up_multiple = math.ceil( + self.num_validation_examples / self.eval_batch_size + ) return rounded_up_multiple * self.eval_batch_size @property diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index b116f44cd..225852b28 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -1,4 +1,4 @@ -r"""Deepspeech. +"""Deepspeech. This model uses a deepspeech2 network to convert speech to text. paper : https://arxiv.org/abs/1512.02595 @@ -10,16 +10,19 @@ from typing import Any, Dict, List, Optional, Tuple, Union +import jax +import jax.numpy as jnp from flax import linen as nn from flax import struct -import jax from jax.experimental import rnn -import jax.numpy as jnp -from algoperf.workloads.librispeech_conformer.librispeech_jax import \ - librispeech_preprocessor as preprocessor -from algoperf.workloads.librispeech_conformer.librispeech_jax import \ - spectrum_augmenter +from algoperf.jax_utils import Dropout +from algoperf.workloads.librispeech_conformer.librispeech_jax import ( + librispeech_preprocessor as preprocessor, +) +from algoperf.workloads.librispeech_conformer.librispeech_jax import ( + spectrum_augmenter, +) Array = jnp.ndarray StateType = Union[Array, Tuple[Array, ...]] @@ -30,10 +33,13 @@ CarryHistory = Any Output = Any +DROPOUT_RATE = 0.1 + @struct.dataclass class DeepspeechConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int = 1024 dtype: Any = jnp.float32 encoder_dim: int = 512 @@ -51,10 +57,6 @@ class DeepspeechConfig: use_dynamic_time_mask_max_frames: bool = True batch_norm_momentum: float = 0.999 batch_norm_epsilon: float = 0.001 - # If None, defaults to 0.1. - input_dropout_rate: Optional[float] = 0.1 - # If None, defaults to 0.1. - feed_forward_dropout_rate: Optional[float] = 0.1 enable_residual_connections: bool = True enable_decoder_layer_norm: bool = True bidirectional: bool = True @@ -69,50 +71,49 @@ class Subsample(nn.Module): encoder_dim: model dimension of conformer. input_dropout_rate: dropout rate for inputs. 
""" + config: DeepspeechConfig @nn.compact - def __call__(self, inputs, output_paddings, train): + def __call__(self, inputs, output_paddings, train, dropout_rate=DROPOUT_RATE): config = self.config outputs = jnp.expand_dims(inputs, axis=-1) outputs, output_paddings = Conv2dSubsampling( - encoder_dim=config.encoder_dim, - dtype=config.dtype, - batch_norm_momentum=config.batch_norm_momentum, - batch_norm_epsilon=config.batch_norm_epsilon, - input_channels=1, - output_channels=config.encoder_dim, - use_tanh=config.use_tanh - )(outputs, output_paddings, train) + encoder_dim=config.encoder_dim, + dtype=config.dtype, + batch_norm_momentum=config.batch_norm_momentum, + batch_norm_epsilon=config.batch_norm_epsilon, + input_channels=1, + output_channels=config.encoder_dim, + use_tanh=config.use_tanh, + )(outputs, output_paddings, train) outputs, output_paddings = Conv2dSubsampling( - encoder_dim=config.encoder_dim, - dtype=config.dtype, - batch_norm_momentum=config.batch_norm_momentum, - batch_norm_epsilon=config.batch_norm_epsilon, - input_channels=config.encoder_dim, - output_channels=config.encoder_dim, - use_tanh=config.use_tanh)(outputs, output_paddings, train) + encoder_dim=config.encoder_dim, + dtype=config.dtype, + batch_norm_momentum=config.batch_norm_momentum, + batch_norm_epsilon=config.batch_norm_epsilon, + input_channels=config.encoder_dim, + output_channels=config.encoder_dim, + use_tanh=config.use_tanh, + )(outputs, output_paddings, train) batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape outputs = jnp.reshape( - outputs, (batch_size, subsampled_lengths, subsampled_dims * channels)) + outputs, (batch_size, subsampled_lengths, subsampled_dims * channels) + ) outputs = nn.Dense( - config.encoder_dim, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - outputs) + config.encoder_dim, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform(), + )(outputs) - if config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = config.input_dropout_rate - outputs = nn.Dropout( - rate=input_dropout_rate, deterministic=not train)( - outputs) + outputs = Dropout(rate=dropout_rate, deterministic=not train)( + outputs, rate=dropout_rate + ) return outputs, output_paddings @@ -124,6 +125,7 @@ class Conv2dSubsampling(nn.Module): 2) Also performs strided convolution over input_paddings to return the correct paddings for downstream layers. """ + input_channels: int = 0 output_channels: int = 0 filter_stride: List[int] = (2, 2) @@ -136,24 +138,26 @@ class Conv2dSubsampling(nn.Module): def setup(self): self.filter_shape = (3, 3, self.input_channels, self.output_channels) - self.kernel = self.param('kernel', - nn.initializers.xavier_uniform(), - self.filter_shape) + self.kernel = self.param( + 'kernel', nn.initializers.xavier_uniform(), self.filter_shape + ) self.bias = self.param( - 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) + 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels + ) @nn.compact def __call__(self, inputs, paddings, train): # Computing strided convolution to subsample inputs. 
feature_group_count = inputs.shape[3] // self.filter_shape[2] outputs = jax.lax.conv_general_dilated( - lhs=inputs, - rhs=self.kernel, - window_strides=self.filter_stride, - padding=self.padding, - rhs_dilation=(1, 1), - dimension_numbers=('NHWC', 'HWIO', 'NHWC'), - feature_group_count=feature_group_count) + lhs=inputs, + rhs=self.kernel, + window_strides=self.filter_stride, + padding=self.padding, + rhs_dilation=(1, 1), + dimension_numbers=('NHWC', 'HWIO', 'NHWC'), + feature_group_count=feature_group_count, + ) outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,)) @@ -168,56 +172,58 @@ def __call__(self, inputs, paddings, train): pad_len = (input_length + stride - 1) // stride * stride - input_length out_padding = jax.lax.conv_general_dilated( - lhs=paddings[:, :, None], - rhs=jnp.ones([1, 1, 1]), - window_strides=self.filter_stride[:1], - padding=[(0, pad_len)], - dimension_numbers=('NHC', 'HIO', 'NHC')) + lhs=paddings[:, :, None], + rhs=jnp.ones([1, 1, 1]), + window_strides=self.filter_stride[:1], + padding=[(0, pad_len)], + dimension_numbers=('NHC', 'HIO', 'NHC'), + ) out_padding = jnp.squeeze(out_padding, axis=-1) # Mask outputs by correct paddings to ensure padded elements in inputs map # to padded value in outputs. - outputs = outputs * (1.0 - - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1)) + outputs = outputs * ( + 1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1) + ) return outputs, out_padding class FeedForwardModule(nn.Module): """Feedforward block of conformer layer.""" + config: DeepspeechConfig @nn.compact - def __call__(self, inputs, input_paddings=None, train=False): + def __call__( + self, inputs, input_paddings=None, train=False, dropout_rate=DROPOUT_RATE + ): padding_mask = jnp.expand_dims(1 - input_paddings, -1) config = self.config if config.layernorm_everywhere: inputs = LayerNorm(config.encoder_dim)(inputs) else: - inputs = BatchNorm(config.encoder_dim, - config.dtype, - config.batch_norm_momentum, - config.batch_norm_epsilon)(inputs, - input_paddings, - train) - inputs = nn.Dense( + inputs = BatchNorm( config.encoder_dim, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - inputs) + config.dtype, + config.batch_norm_momentum, + config.batch_norm_epsilon, + )(inputs, input_paddings, train) + inputs = nn.Dense( + config.encoder_dim, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform(), + )(inputs) if config.use_tanh: inputs = nn.tanh(inputs) else: inputs = nn.relu(inputs) inputs *= padding_mask - if config.feed_forward_dropout_rate is None: - feed_forward_dropout_rate = 0.1 - else: - feed_forward_dropout_rate = config.feed_forward_dropout_rate - inputs = nn.Dropout(rate=feed_forward_dropout_rate)( - inputs, deterministic=not train) + inputs = Dropout(rate=dropout_rate)( + inputs, deterministic=not train, rate=dropout_rate + ) return inputs @@ -232,6 +238,7 @@ class LayerNorm(nn.Module): zeros, this differs from default flax implementation of multiplying by scale and initializing to ones. """ + dim: int = 0 epsilon: float = 1e-6 @@ -245,7 +252,7 @@ def __call__(self, inputs): var = jnp.mean(jnp.square(inputs - mean), axis=-1, keepdims=True) normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon) - normed_inputs *= (1 + self.scale) + normed_inputs *= 1 + self.scale normed_inputs += self.bias return normed_inputs @@ -263,6 +270,7 @@ class BatchNorm(nn.Module): and the corresponding defaults for momentum and epsilon have been copied over from lingvo. 
""" + encoder_dim: int = 0 dtype: Any = jnp.float32 batch_norm_momentum: float = 0.999 @@ -272,14 +280,12 @@ def setup(self): dim = self.encoder_dim dtype = self.dtype - self.ra_mean = self.variable('batch_stats', - 'mean', - lambda s: jnp.zeros(s, dtype), - dim) - self.ra_var = self.variable('batch_stats', - 'var', - lambda s: jnp.ones(s, dtype), - dim) + self.ra_mean = self.variable( + 'batch_stats', 'mean', lambda s: jnp.zeros(s, dtype), dim + ) + self.ra_var = self.variable( + 'batch_stats', 'var', lambda s: jnp.ones(s, dtype), dim + ) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) @@ -308,7 +314,8 @@ def __call__(self, inputs, input_paddings=None, train=False): mask = 1.0 - padding sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True) count_v = jnp.sum( - jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True) + jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True + ) sum_v = jax.lax.psum(sum_v, axis_name='batch') count_v = jax.lax.psum(count_v, axis_name='batch') @@ -345,15 +352,14 @@ class CudnnLSTM(nn.Module): @nn.compact def __call__( - self, - inputs: Array, - segmentation_mask: Optional[Array] = None, - return_carry: Optional[bool] = None, - deterministic: bool = False, - initial_states: Optional[Tuple[Array, Array]] = None, - use_cuda: bool = True, + self, + inputs: Array, + segmentation_mask: Optional[Array] = None, + return_carry: Optional[bool] = None, + deterministic: bool = False, + initial_states: Optional[Tuple[Array, Array]] = None, + use_cuda: bool = True, ) -> Union[Array, Tuple[Array, Carry]]: - if jax.devices()[0].platform != 'gpu': use_cuda = False @@ -363,22 +369,22 @@ def __call__( dropout = 0.0 if deterministic else self.dropout_rate weights = self.param( - 'weights', - rnn.init_lstm_weight, - input_size, - self.features, - self.num_layers, - self.bidirectional, + 'weights', + rnn.init_lstm_weight, + input_size, + self.features, + self.num_layers, + self.bidirectional, ) if initial_states is None: h_0 = jnp.zeros( - (num_directions * self.num_layers, batch_size, self.features), - jnp.float32, + (num_directions * self.num_layers, batch_size, self.features), + jnp.float32, ) c_0 = jnp.zeros( - (num_directions * self.num_layers, batch_size, self.features), - jnp.float32, + (num_directions * self.num_layers, batch_size, self.features), + jnp.float32, ) else: h_0, c_0 = initial_states @@ -390,20 +396,35 @@ def __call__( if use_cuda: y, h, c = rnn.lstm( - x=inputs, h_0=h_0, c_0=c_0, weights=weights, - seq_lengths=seq_lengths, input_size=input_size, - hidden_size=self.features, num_layers=self.num_layers, - dropout=dropout, bidirectional=self.bidirectional, + x=inputs, + h_0=h_0, + c_0=c_0, + weights=weights, + seq_lengths=seq_lengths, + input_size=input_size, + hidden_size=self.features, + num_layers=self.num_layers, + dropout=dropout, + bidirectional=self.bidirectional, ) else: weight_ih, weight_hh, bias_ih, bias_hh = self.unpack_weights( - weights, input_size) + weights, input_size + ) y, h, c = rnn.lstm_ref( - x=inputs, h_0=h_0, c_0=c_0, W_ih=weight_ih, W_hh=weight_hh, - b_ih=bias_ih, b_hh=bias_hh, seq_lengths=seq_lengths, - input_size=input_size, hidden_size=self.features, - num_layers=self.num_layers, dropout=dropout, - bidirectional=self.bidirectional, + x=inputs, + h_0=h_0, + c_0=c_0, + W_ih=weight_ih, + W_hh=weight_hh, + b_ih=bias_ih, + b_hh=bias_hh, + seq_lengths=seq_lengths, + input_size=input_size, + hidden_size=self.features, + 
num_layers=self.num_layers, + dropout=dropout, + bidirectional=self.bidirectional, ) if return_carry: @@ -413,21 +434,22 @@ def __call__( @nn.nowrap def unpack_weights( - self, weights: Array, input_size: int + self, weights: Array, input_size: int ) -> Tuple[ - Dict[int, Array], Dict[int, Array], Dict[int, Array], Dict[int, Array]]: + Dict[int, Array], Dict[int, Array], Dict[int, Array], Dict[int, Array] + ]: return jax.experimental.rnn.unpack_lstm_weights( - weights, - input_size, - self.features, - self.num_layers, - self.bidirectional, + weights, + input_size, + self.features, + self.num_layers, + self.bidirectional, ) class BatchRNN(nn.Module): - """Implements a single deepspeech encoder layer. - """ + """Implements a single deepspeech encoder layer.""" + config: DeepspeechConfig @nn.compact @@ -437,16 +459,17 @@ def __call__(self, inputs, input_paddings, train): if config.layernorm_everywhere: inputs = LayerNorm(config.encoder_dim)(inputs) else: - inputs = BatchNorm(config.encoder_dim, - config.dtype, - config.batch_norm_momentum, - config.batch_norm_epsilon)(inputs, - input_paddings, - train) + inputs = BatchNorm( + config.encoder_dim, + config.dtype, + config.batch_norm_momentum, + config.batch_norm_epsilon, + )(inputs, input_paddings, train) output = CudnnLSTM( - features=config.encoder_dim // 2, - bidirectional=config.bidirectional, - num_layers=1)(inputs, input_paddings) + features=config.encoder_dim // 2, + bidirectional=config.bidirectional, + num_layers=1, + )(inputs, input_paddings) return output @@ -458,22 +481,23 @@ class Deepspeech(nn.Module): for each time step. The output is then fed into a CTC loss which eliminates the need for alignment with targets. """ + config: DeepspeechConfig def setup(self): config = self.config self.specaug = spectrum_augmenter.SpecAug( - freq_mask_count=config.freq_mask_count, - freq_mask_max_bins=config.freq_mask_max_bins, - time_mask_count=config.time_mask_count, - time_mask_max_frames=config.time_mask_max_frames, - time_mask_max_ratio=config.time_mask_max_ratio, - time_masks_per_frame=config.time_masks_per_frame, - use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames + freq_mask_count=config.freq_mask_count, + freq_mask_max_bins=config.freq_mask_max_bins, + time_mask_count=config.time_mask_count, + time_mask_max_frames=config.time_mask_max_frames, + time_mask_max_ratio=config.time_mask_max_ratio, + time_masks_per_frame=config.time_masks_per_frame, + use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames, ) @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE): config = self.config outputs = inputs @@ -482,10 +506,10 @@ def __call__(self, inputs, input_paddings, train): # Compute normalized log mel spectrograms from input audio signal. preprocessing_config = preprocessor.LibrispeechPreprocessingConfig() outputs, output_paddings = preprocessor.MelFilterbankFrontend( - preprocessing_config, - per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, - per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)(outputs, - output_paddings) + preprocessing_config, + per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, + per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR, + )(outputs, output_paddings) # Ablate random parts of input along temporal and frequency dimension # following the specaug procedure in https://arxiv.org/abs/1904.08779. 
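# --- Editor's aside (illustration only, not part of the patch) ---------------
# The core change in this file is that the dropout probability is now threaded
# through the call path (`dropout_rate=DROPOUT_RATE`) instead of being baked
# into DeepspeechConfig, via the custom `Dropout` imported from
# `algoperf.jax_utils`, which also accepts the rate at call time. That module's
# real implementation is not shown in this diff; a minimal flax sketch of the
# idea (hypothetical class name and code) could look like:
from typing import Optional

import jax
import jax.numpy as jnp
from flax import linen as nn

class CallTimeDropout(nn.Module):
  rate: float = 0.1  # construction-time default, overridable per call

  @nn.compact
  def __call__(self, x, deterministic: bool, rate: Optional[float] = None):
    rate = self.rate if rate is None else rate
    if deterministic or rate == 0.0:
      return x
    keep_prob = 1.0 - rate
    rng = self.make_rng('dropout')
    mask = jax.random.bernoulli(rng, p=keep_prob, shape=x.shape)
    return jnp.where(mask, x / keep_prob, 0.0)

# Usage (inside another module, with rngs={'dropout': key} passed to .apply):
#   x = CallTimeDropout()(x, deterministic=not train, rate=dropout_rate)
# ------------------------------------------------------------------------------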
@@ -493,8 +517,9 @@ def __call__(self, inputs, input_paddings, train): outputs, output_paddings = self.specaug(outputs, output_paddings) # Subsample input by a factor of 4 by performing strided convolutions. - outputs, output_paddings = Subsample( - config=config)(outputs, output_paddings, train) + outputs, output_paddings = Subsample(config=config)( + outputs, output_paddings, train, dropout_rate=dropout_rate + ) # Run the lstm layers. for _ in range(config.num_lstm_layers): @@ -506,20 +531,21 @@ def __call__(self, inputs, input_paddings, train): for _ in range(config.num_ffn_layers): if config.enable_residual_connections: outputs = outputs + FeedForwardModule(config=self.config)( - outputs, output_paddings, train) + outputs, output_paddings, train + ) else: - outputs = FeedForwardModule(config=self.config)(outputs, - output_paddings, - train) + outputs = FeedForwardModule(config=self.config)( + outputs, output_paddings, train, dropout_rate=dropout_rate + ) # Run the decoder which in this case is a trivial projection layer. if config.enable_decoder_layer_norm: outputs = LayerNorm(config.encoder_dim)(outputs) outputs = nn.Dense( - config.vocab_size, - use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - outputs) + config.vocab_size, + use_bias=True, + kernel_init=nn.initializers.xavier_uniform(), + )(outputs) return outputs, output_paddings diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index d3b616f43..b93934abf 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -1,40 +1,29 @@ import functools from typing import Dict, Optional, Tuple -from flax import jax_utils import jax import jax.numpy as jnp import numpy as np +from flax import jax_utils -from algoperf import param_utils -from algoperf import spec -from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ - LibriSpeechConformerWorkload +from algoperf import param_utils, spec +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import ( + LibriSpeechConformerWorkload, +) from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): - - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Deepspeech model init function. - - Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate - as input_dropout_rate. 
- """ + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: + """Deepspeech model init function.""" model_config = models.DeepspeechConfig( - feed_forward_dropout_rate=dropout_rate, - use_specaug=self.use_specaug, - input_dropout_rate=aux_dropout_rate, - use_tanh=self.use_tanh, - enable_residual_connections=self.enable_residual_connections, - enable_decoder_layer_norm=self.enable_decoder_layer_norm, - layernorm_everywhere=self.layernorm_everywhere, - freq_mask_count=self.freq_mask_count, - time_mask_count=self.time_mask_count, + use_specaug=self.use_specaug, + use_tanh=self.use_tanh, + enable_residual_connections=self.enable_residual_connections, + enable_decoder_layer_norm=self.enable_decoder_layer_norm, + layernorm_everywhere=self.layernorm_everywhere, + freq_mask_count=self.freq_mask_count, + time_mask_count=self.time_mask_count, ) self._model = models.Deepspeech(model_config) input_shape = [(320000,), (320000,)] @@ -42,12 +31,17 @@ def init_model_fn( model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) - params_rng, dropout_rng = jax.random.split(rng, 2) - variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, - *fake_input_batch) + params_rng, _ = jax.random.split(rng, 2) + variables = model_init_fn( + { + 'params': params_rng, + }, + *fake_input_batch, + ) - model_state = variables[ - 'batch_stats'] if not self.layernorm_everywhere else {} + model_state = ( + variables['batch_stats'] if not self.layernorm_everywhere else {} + ) params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -56,34 +50,34 @@ def init_model_fn( return params, model_state def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None, + dropout_rate: Optional[bool] = models.DROPOUT_RATE, ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN if update_batch_norm or is_train_mode: (logits, logit_paddings), new_model_state = self._model.apply( - variables, - inputs, - input_paddings, - train=True, - rngs={'dropout' : rng}, - mutable=['batch_stats']) + variables, + inputs, + input_paddings, + train=True, + rngs={'dropout': rng}, + mutable=['batch_stats'], + dropout_rate=dropout_rate, + ) return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( - variables, - inputs, - input_paddings, - train=False, - mutable=False) + variables, inputs, input_paddings, train=False, mutable=False + ) return (logits, logit_paddings), model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -132,7 +126,6 @@ def time_mask_count(self) -> int: class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechDeepSpeechWorkload): - @property def use_tanh(self) -> bool: return True @@ -147,7 +140,6 @@ def test_target_value(self) -> float: class 
LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechDeepSpeechWorkload): - @property def enable_residual_connections(self) -> bool: return False @@ -161,9 +153,9 @@ def test_target_value(self) -> float: return 0.079297 -class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload - ): - +class LibriSpeechDeepSpeechNormAndSpecAugWorkload( + LibriSpeechDeepSpeechWorkload +): @property def eval_batch_size(self) -> int: return 128 diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index 84d317326..aab75da63 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -2,26 +2,30 @@ https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py. """ -from dataclasses import dataclass import os -from typing import Optional, Tuple +from dataclasses import dataclass +from typing import Tuple import torch -from torch import nn import torch.distributed.nn as dist_nn import torch.nn.functional as F +from torch import nn -from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ - preprocessor -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ - SpecAug +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import ( + preprocessor, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import ( + SpecAug, +) USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ +DROPOUT_RATE = 0.1 @dataclass class DeepspeechConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int = 1024 encoder_dim: int = 512 num_lstm_layers: int = 6 @@ -38,10 +42,6 @@ class DeepspeechConfig: use_dynamic_time_mask_max_frames: bool = True batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 - # If None, defaults to 0.1. - input_dropout_rate: Optional[float] = 0.1 - # If None, defaults to 0.1. 
- feed_forward_dropout_rate: Optional[float] = 0.1 enable_residual_connections: bool = True enable_decoder_layer_norm: bool = True bidirectional: bool = True @@ -50,7 +50,6 @@ class DeepspeechConfig: class LayerNorm(nn.Module): - def __init__(self, dim, epsilon=1e-6): super().__init__() self.dim = dim @@ -64,14 +63,13 @@ def forward(self, x): var = x.var(dim=-1, unbiased=False, keepdims=True) normed_x = (x - mean) * torch.rsqrt(var + self.epsilon) - normed_x *= (1 + self.scale) + normed_x *= 1 + self.scale normed_x += self.bias return normed_x class Subsample(nn.Module): - def __init__(self, config: DeepspeechConfig): super().__init__() encoder_dim = config.encoder_dim @@ -79,21 +77,17 @@ def __init__(self, config: DeepspeechConfig): self.encoder_dim = encoder_dim self.conv1 = Conv2dSubsampling( - input_channels=1, output_channels=encoder_dim, use_tanh=config.use_tanh) + input_channels=1, output_channels=encoder_dim, use_tanh=config.use_tanh + ) self.conv2 = Conv2dSubsampling( - input_channels=encoder_dim, - output_channels=encoder_dim, - use_tanh=config.use_tanh) + input_channels=encoder_dim, + output_channels=encoder_dim, + use_tanh=config.use_tanh, + ) self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True) - if config.input_dropout_rate is None: - input_dropout_rate = 0.1 - else: - input_dropout_rate = config.input_dropout_rate - self.dropout = nn.Dropout(p=input_dropout_rate) - - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate): output_paddings = input_paddings outputs = inputs[:, None, :, :] @@ -101,26 +95,27 @@ def forward(self, inputs, input_paddings): outputs, output_paddings = self.conv2(outputs, output_paddings) batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape - outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size, - subsampled_lengths, - subsampled_dims * channels) + outputs = outputs.permute(0, 2, 3, 1).reshape( + batch_size, subsampled_lengths, subsampled_dims * channels + ) outputs = self.lin(outputs) - outputs = self.dropout(outputs) + outputs = F.dropout(outputs, dropout_rate, training=self.training) return outputs, output_paddings class Conv2dSubsampling(nn.Module): - - def __init__(self, - input_channels: int, - output_channels: int, - filter_stride: Tuple[int] = (2, 2), - padding: str = 'SAME', - batch_norm_momentum: float = 0.999, - batch_norm_epsilon: float = 0.001, - use_tanh: bool = False): + def __init__( + self, + input_channels: int, + output_channels: int, + filter_stride: Tuple[int] = (2, 2), + padding: str = 'SAME', + batch_norm_momentum: float = 0.999, + batch_norm_epsilon: float = 0.001, + use_tanh: bool = False, + ): super().__init__() self.input_channels = input_channels @@ -131,7 +126,8 @@ def __init__(self, self.filter_shape = (output_channels, input_channels, 3, 3) self.kernel = nn.Parameter( - nn.init.xavier_uniform_(torch.empty(*self.filter_shape))) + nn.init.xavier_uniform_(torch.empty(*self.filter_shape)) + ) self.bias = nn.Parameter(torch.zeros(output_channels)) self.use_tanh = use_tanh @@ -162,12 +158,13 @@ def forward(self, inputs, paddings): else: in_ = inputs outputs = F.conv2d( - input=in_, - weight=self.kernel, - bias=self.bias, - stride=self.filter_stride, - dilation=(1, 1), - groups=groups) + input=in_, + weight=self.kernel, + bias=self.bias, + stride=self.filter_stride, + dilation=(1, 1), + groups=groups, + ) if self.use_tanh: outputs = F.tanh(outputs) @@ -178,21 +175,24 @@ def forward(self, inputs, paddings): stride = self.filter_stride[0] pad_len 
= (input_length + stride - 1) // stride * stride - input_length out_padding = F.conv1d( - input=torch.cat([ - paddings[:, None, :], - torch.zeros( - size=(paddings.shape[0], 1, pad_len), device=paddings.device) + input=torch.cat( + [ + paddings[:, None, :], + torch.zeros( + size=(paddings.shape[0], 1, pad_len), device=paddings.device + ), ], - dim=2), - weight=torch.ones([1, 1, 1], device=paddings.device), - stride=self.filter_stride[:1]) + dim=2, + ), + weight=torch.ones([1, 1, 1], device=paddings.device), + stride=self.filter_stride[:1], + ) out_padding = out_padding.squeeze(dim=1) outputs = outputs * (1 - out_padding[:, None, :, None]) return outputs, out_padding class FeedForwardModule(nn.Module): - def __init__(self, config: DeepspeechConfig): super().__init__() self.config = config @@ -201,17 +201,13 @@ def __init__(self, config: DeepspeechConfig): self.normalization_layer = LayerNorm(config.encoder_dim) else: self.bn_normalization_layer = BatchNorm( - dim=config.encoder_dim, - batch_norm_momentum=config.batch_norm_momentum, - batch_norm_epsilon=config.batch_norm_epsilon) + dim=config.encoder_dim, + batch_norm_momentum=config.batch_norm_momentum, + batch_norm_epsilon=config.batch_norm_epsilon, + ) self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True) - if config.feed_forward_dropout_rate is None: - feed_forward_dropout_rate = 0.1 - else: - feed_forward_dropout_rate = config.feed_forward_dropout_rate - self.dropout = nn.Dropout(p=feed_forward_dropout_rate) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate): padding_mask = (1 - input_paddings)[:, :, None] if self.config.layernorm_everywhere: inputs = self.normalization_layer(inputs) @@ -226,13 +222,12 @@ def forward(self, inputs, input_paddings): inputs = F.relu(inputs) inputs = inputs * padding_mask - inputs = self.dropout(inputs) + inputs = F.dropout(inputs, dropout_rate, training=self.training) return inputs class BatchNorm(nn.Module): - def __init__(self, dim, batch_norm_momentum, batch_norm_epsilon): super().__init__() running_mean = torch.zeros(dim) @@ -247,8 +242,8 @@ def __init__(self, dim, batch_norm_momentum, batch_norm_epsilon): self.dim = dim def forward(self, inputs, input_paddings): - #inputs: NHD - #padding: NH + # inputs: NHD + # padding: NH mask = 1 - input_paddings[:, :, None] if self.training: count = mask.sum() @@ -265,9 +260,11 @@ def forward(self, inputs, input_paddings): var = sum_ / count self.running_mean = (1 - self.momentum) * self.running_mean + ( - self.momentum) * mean.detach() + self.momentum + ) * mean.detach() self.running_var = (1 - self.momentum) * self.running_var + ( - self.momentum) * var.detach() + self.momentum + ) * var.detach() else: mean = self.running_mean var = self.running_var @@ -278,7 +275,6 @@ def forward(self, inputs, input_paddings): class BatchRNN(nn.Module): - def __init__(self, config: DeepspeechConfig): super().__init__() self.config = config @@ -290,19 +286,23 @@ def __init__(self, config: DeepspeechConfig): if config.layernorm_everywhere: self.normalization_layer = LayerNorm(config.encoder_dim) else: - self.bn_normalization_layer = BatchNorm(config.encoder_dim, - config.batch_norm_momentum, - config.batch_norm_epsilon) + self.bn_normalization_layer = BatchNorm( + config.encoder_dim, + config.batch_norm_momentum, + config.batch_norm_epsilon, + ) if bidirectional: self.lstm = nn.LSTM( - input_size=input_size, - hidden_size=hidden_size // 2, - bidirectional=True, - batch_first=True) + input_size=input_size, + 
hidden_size=hidden_size // 2, + bidirectional=True, + batch_first=True, + ) else: self.lstm = nn.LSTM( - input_size=input_size, hidden_size=hidden_size, batch_first=True) + input_size=input_size, hidden_size=hidden_size, batch_first=True + ) def forward(self, inputs, input_paddings): if self.config.layernorm_everywhere: @@ -311,50 +311,59 @@ def forward(self, inputs, input_paddings): inputs = self.bn_normalization_layer(inputs, input_paddings) lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy() packed_inputs = torch.nn.utils.rnn.pack_padded_sequence( - inputs, lengths, batch_first=True, enforce_sorted=False) + inputs, lengths, batch_first=True, enforce_sorted=False + ) packed_outputs, _ = self.lstm(packed_inputs) outputs, _ = torch.nn.utils.rnn.pad_packed_sequence( - packed_outputs, batch_first=True) + packed_outputs, batch_first=True + ) if outputs.shape[1] < inputs.shape[1]: - outputs = torch.cat([ + outputs = torch.cat( + [ outputs, torch.zeros( - size=(outputs.shape[0], - inputs.shape[1] - outputs.shape[1], - outputs.shape[2]), - device=outputs.device) - ], - dim=1) + size=( + outputs.shape[0], + inputs.shape[1] - outputs.shape[1], + outputs.shape[2], + ), + device=outputs.device, + ), + ], + dim=1, + ) return outputs class DeepspeechEncoderDecoder(nn.Module): - def __init__(self, config: DeepspeechConfig): super().__init__() self.config = config self.specaug = SpecAug( - freq_mask_count=config.freq_mask_count, - freq_mask_max_bins=config.freq_mask_max_bins, - time_mask_count=config.time_mask_count, - time_mask_max_frames=config.time_mask_max_frames, - time_mask_max_ratio=config.time_mask_max_ratio, - time_masks_per_frame=config.time_masks_per_frame, - use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames + freq_mask_count=config.freq_mask_count, + freq_mask_max_bins=config.freq_mask_max_bins, + time_mask_count=config.time_mask_count, + time_mask_max_frames=config.time_mask_max_frames, + time_mask_max_ratio=config.time_mask_max_ratio, + time_masks_per_frame=config.time_masks_per_frame, + use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames, ) preprocessing_config = preprocessor.PreprocessorConfig() self.preprocessor = preprocessor.MelFilterbankFrontend( - preprocessing_config, - per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, - per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR) + preprocessing_config, + per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, + per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR, + ) self.subsample = Subsample(config=config) self.lstms = nn.ModuleList( - [BatchRNN(config) for _ in range(config.num_lstm_layers)]) + [BatchRNN(config) for _ in range(config.num_lstm_layers)] + ) self.ffns = nn.ModuleList( - [FeedForwardModule(config) for _ in range(config.num_ffn_layers)]) + [FeedForwardModule(config) for _ in range(config.num_ffn_layers)] + ) if config.enable_decoder_layer_norm: self.ln = LayerNorm(config.encoder_dim) @@ -363,14 +372,16 @@ def __init__(self, config: DeepspeechConfig): self.lin = nn.Linear(config.encoder_dim, config.vocab_size) - def forward(self, inputs, input_paddings): + def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE): outputs = inputs output_paddings = input_paddings outputs, output_paddings = self.preprocessor(outputs, output_paddings) if self.training and self.config.use_specaug: outputs, output_paddings = self.specaug(outputs, output_paddings) - outputs, output_paddings = self.subsample(outputs, output_paddings) + outputs, output_paddings = self.subsample( + 
outputs, output_paddings, dropout_rate + ) for idx in range(self.config.num_lstm_layers): if self.config.enable_residual_connections: outputs = outputs + self.lstms[idx](outputs, output_paddings) @@ -379,9 +390,11 @@ def forward(self, inputs, input_paddings): for idx in range(self.config.num_ffn_layers): if self.config.enable_residual_connections: - outputs = outputs + self.ffns[idx](outputs, output_paddings) + outputs = outputs + self.ffns[idx]( + outputs, output_paddings, dropout_rate + ) else: - outputs = self.ffns[idx](outputs, output_paddings) + outputs = self.ffns[idx](outputs, output_paddings, dropout_rate) if self.config.enable_decoder_layer_norm: outputs = self.ln(outputs) diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index e5387f5cb..672f3440f 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -1,19 +1,21 @@ -from typing import Optional +from typing import Dict, Tuple import torch from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils -from algoperf import spec +from algoperf import param_utils, spec from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ - initialize -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ - DeepspeechConfig -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ - DeepspeechEncoderDecoder +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import models +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import ( + initialize, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import ( + LibriSpeechConformerWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import ( + DeepspeechConfig, + DeepspeechEncoderDecoder, +) USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() @@ -21,29 +23,20 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): - - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Deepspeech model init function. - - Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate - as input_dropout_rate. 
- """ + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: + """Deepspeech model init function.""" torch.random.manual_seed(rng[0]) model = DeepspeechEncoderDecoder( - DeepspeechConfig( - feed_forward_dropout_rate=dropout_rate, - use_specaug=self.use_specaug, - input_dropout_rate=aux_dropout_rate, - use_tanh=self.use_tanh, - enable_residual_connections=self.enable_residual_connections, - enable_decoder_layer_norm=self.enable_decoder_layer_norm, - layernorm_everywhere=self.layernorm_everywhere, - freq_mask_count=self.freq_mask_count, - time_mask_count=self.time_mask_count)).eval() + DeepspeechConfig( + use_specaug=self.use_specaug, + use_tanh=self.use_tanh, + enable_residual_connections=self.enable_residual_connections, + enable_decoder_layer_norm=self.enable_decoder_layer_norm, + layernorm_everywhere=self.layernorm_everywhere, + freq_mask_count=self.freq_mask_count, + time_mask_count=self.time_mask_count, + ) + ).eval() self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none') # Run model once to initialize lazy layers. t = MAX_INPUT_LENGTH @@ -63,6 +56,28 @@ def init_model_fn( model = torch.nn.DataParallel(model) return model, None + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + # override super method, changing only the default dropout_rate + # pylint: disable=useless-parent-delegation + return super().model_fn( + params, + augmented_and_preprocessed_input_batch, + model_state, + mode, + rng, + update_batch_norm, + dropout_rate, + ) + def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['lin.weight', 'lin.bias'] @@ -109,7 +124,6 @@ def time_mask_count(self) -> int: class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechDeepSpeechWorkload): - @property def use_tanh(self) -> bool: return True @@ -124,7 +138,6 @@ def test_target_value(self) -> float: class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechDeepSpeechWorkload): - @property def enable_residual_connections(self) -> bool: return False @@ -138,9 +151,9 @@ def test_target_value(self) -> float: return 0.079297 -class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload - ): - +class LibriSpeechDeepSpeechNormAndSpecAugWorkload( + LibriSpeechDeepSpeechWorkload +): @property def eval_batch_size(self) -> int: return 128 diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 5a4382da1..0d192cbf5 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -3,20 +3,18 @@ import functools from typing import Any, Dict, Optional, Tuple -from flax import jax_utils -from flax import linen as nn import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from flax import linen as nn +from jax import lax -from algoperf import param_utils -from algoperf import spec +from algoperf import param_utils, spec from algoperf.workloads.mnist.workload import BaseMnistWorkload class _Model(nn.Module): - @nn.compact def __call__(self, x: spec.Tensor, train: bool) -> spec.Tensor: del train @@ -31,19 +29,12 @@ def __call__(self, x: spec.Tensor, train: bool) -> spec.Tensor: class MnistWorkload(BaseMnistWorkload): - - def init_model_fn( - self, 
- rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """Dropout is unused.""" - del dropout_rate - del aux_dropout_rate + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: init_val = jnp.ones((1, 28, 28, 1), jnp.float32) self._model = _Model() - initial_params = self._model.init({'params': rng}, init_val, - train=True)['params'] + initial_params = self._model.init({'params': rng}, init_val, train=True)[ + 'params' + ] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) return jax_utils.replicate(initial_params), None @@ -52,31 +43,34 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_1' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm train = mode == spec.ForwardPassMode.TRAIN logits_batch = self._model.apply( - {'params': params}, - augmented_and_preprocessed_input_batch['inputs'], - train=train) + {'params': params}, + augmented_and_preprocessed_input_batch['inputs'], + train=train, + ) return logits_batch, None # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -86,7 +80,8 @@ def loss_fn( one_hot_targets = jax.nn.one_hot(label_batch, 10) smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing) per_example_losses = -jnp.sum( - smoothed_targets * nn.log_softmax(logits_batch), axis=-1) + smoothed_targets * nn.log_softmax(logits_batch), axis=-1 + ) # `mask_batch` is assumed to be shape [batch]. 
if mask_batch is not None: per_example_losses *= mask_batch @@ -95,41 +90,45 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.pmap, + axis_name='batch', + in_axes=(None, 0, 0, 0, None), + static_broadcasted_argnums=(0,), + ) def _eval_model( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) weights = batch.get('weights') if weights is None: weights = jnp.ones(len(logits)) accuracy = jnp.sum( - (jnp.argmax(logits, axis=-1) == batch['targets']) * weights) + (jnp.argmax(logits, axis=-1) == batch['targets']) * weights + ) summed_loss = self.loss_fn(batch['targets'], logits, weights)['summed'] metrics = {'accuracy': accuracy, 'loss': summed_loss} metrics = lax.psum(metrics, axis_name='batch') return metrics def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) diff --git a/algoperf/workloads/mnist/mnist_pytorch/workload.py b/algoperf/workloads/mnist/mnist_pytorch/workload.py index 780e1bca0..b58898703 100644 --- a/algoperf/workloads/mnist/mnist_pytorch/workload.py +++ b/algoperf/workloads/mnist/mnist_pytorch/workload.py @@ -1,18 +1,16 @@ """MNIST workload implemented in PyTorch.""" -from collections import OrderedDict import contextlib +from collections import OrderedDict from typing import Any, Dict, Iterator, Optional, Tuple import torch -from torch import nn import torch.distributed as dist import torch.nn.functional as F +from torch import nn from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import init_utils -from algoperf import param_utils -from algoperf import spec +from algoperf import init_utils, param_utils, spec from algoperf.pytorch_utils import pytorch_setup from algoperf.workloads.mnist.workload import BaseMnistWorkload @@ -20,18 +18,20 @@ class _Model(nn.Module): - def __init__(self) -> None: super().__init__() input_size = 28 * 28 num_hidden = 128 num_classes = 10 self.net = nn.Sequential( - OrderedDict([('layer1', - torch.nn.Linear(input_size, num_hidden, bias=True)), - ('layer1_sig', torch.nn.Sigmoid()), - ('layer2', - torch.nn.Linear(num_hidden, num_classes, bias=True))])) + OrderedDict( + [ + ('layer1', torch.nn.Linear(input_size, num_hidden, bias=True)), + ('layer1_sig', torch.nn.Sigmoid()), + ('layer2', torch.nn.Linear(num_hidden, num_classes, bias=True)), + ] + ) + ) def reset_parameters(self) -> None: for m in self.net.modules(): @@ -44,16 +44,16 @@ def forward(self, x: 
spec.Tensor) -> spec.Tensor: class MnistWorkload(BaseMnistWorkload): - def _build_input_queue( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del cache if N_GPUS != 0: per_device_batch_size = int(global_batch_size / N_GPUS) @@ -63,22 +63,27 @@ def _build_input_queue( # Only create and iterate over tf input pipeline in one Python process to # avoid creating too many threads. if RANK == 0: - np_iter = super()._build_input_queue(data_rng, - split, - data_dir, - global_batch_size, - num_batches, - repeat_final_dataset) + np_iter = super()._build_input_queue( + data_rng, + split, + data_dir, + global_batch_size, + num_batches, + repeat_final_dataset, + ) while True: if RANK == 0: batch = next(np_iter) # pylint: disable=stop-iteration-return inputs = torch.as_tensor( - batch['inputs'], dtype=torch.float32, device=DEVICE) + batch['inputs'], dtype=torch.float32, device=DEVICE + ) targets = torch.as_tensor( - batch['targets'], dtype=torch.long, device=DEVICE) + batch['targets'], dtype=torch.long, device=DEVICE + ) if 'weights' in batch: weights = torch.as_tensor( - batch['weights'], dtype=torch.bool, device=DEVICE) + batch['weights'], dtype=torch.bool, device=DEVICE + ) else: weights = torch.ones_like(targets, dtype=torch.bool, device=DEVICE) # Send batch to other devices when using DDP. @@ -94,34 +99,37 @@ def _build_input_queue( targets = targets.view(-1, *targets.shape[2:]) weights = weights.view(-1, *weights.shape[2:]) else: - inputs = torch.empty((N_GPUS, per_device_batch_size, 28, 28, 1), - dtype=torch.float32, - device=DEVICE) + inputs = torch.empty( + (N_GPUS, per_device_batch_size, 28, 28, 1), + dtype=torch.float32, + device=DEVICE, + ) dist.broadcast(inputs, src=0) inputs = inputs[RANK] - targets = torch.empty((N_GPUS, per_device_batch_size), - dtype=torch.long, - device=DEVICE) + targets = torch.empty( + (N_GPUS, per_device_batch_size), dtype=torch.long, device=DEVICE + ) dist.broadcast(targets, src=0) targets = targets[RANK] - weights = torch.empty((N_GPUS, per_device_batch_size), - dtype=torch.bool, - device=DEVICE) + weights = torch.empty( + (N_GPUS, per_device_batch_size), dtype=torch.bool, device=DEVICE + ) dist.broadcast(weights, src=0) weights = weights[RANK] batch = { - 'inputs': inputs.permute(0, 3, 1, 2), - 'targets': targets, - 'weights': weights, + 'inputs': inputs.permute(0, 3, 1, 2), + 'targets': targets, + 'weights': weights, } yield batch def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + self, + rng: spec.RandomState, + dropout_rate: Optional[float] = None, + aux_dropout_rate: Optional[float] = None, + ) -> spec.ModelInitState: """Dropout is unused.""" del dropout_rate del aux_dropout_rate @@ -149,13 +157,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['net.layer2.weight', 'net_layer2.bias'] def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: 
spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -163,8 +172,8 @@ def model_fn( if mode == spec.ForwardPassMode.EVAL: model.eval() contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) @@ -173,11 +182,12 @@ def model_fn( # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -185,10 +195,11 @@ def loss_fn( (not synced across devices). """ per_example_losses = F.cross_entropy( - logits_batch, - label_batch, - reduction='none', - label_smoothing=label_smoothing) + logits_batch, + label_batch, + reduction='none', + label_smoothing=label_smoothing, + ) # `mask_batch` is assumed to be shape [batch]. 
if mask_batch is not None: per_example_losses *= mask_batch @@ -197,25 +208,27 @@ def loss_fn( n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), + 'per_example': per_example_losses, } def _eval_model( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: """Return the mean accuracy and loss as a dict.""" logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) targets = batch['targets'] weights = batch.get('weights') if weights is None: @@ -227,8 +240,8 @@ def _eval_model( return {'accuracy': accuracy, 'loss': summed_loss} def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algoperf/workloads/mnist/workload.py b/algoperf/workloads/mnist/workload.py index f53aadd0b..38006b9ac 100644 --- a/algoperf/workloads/mnist/workload.py +++ b/algoperf/workloads/mnist/workload.py @@ -10,10 +10,9 @@ import tensorflow_datasets as tfds import torch -from algoperf import data_utils -from algoperf import spec -from algoperf.pytorch_utils import pytorch_setup import algoperf.random_utils as prng +from algoperf import data_utils, spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP, _, _, _ = pytorch_setup() @@ -23,16 +22,17 @@ def _normalize(image: spec.Tensor, mean: float, stddev: float) -> spec.Tensor: def _build_mnist_dataset( - data_rng: jax.random.PRNGKey, - num_train_examples: int, - num_validation_examples: int, - train_mean: float, - train_stddev: float, - split: str, - data_dir: str, - global_batch_size: int, - cache: bool = False, - repeat_final_dataset: bool = True) -> Iterator[Dict[str, spec.Tensor]]: + data_rng: jax.random.PRNGKey, + num_train_examples: int, + num_validation_examples: int, + train_mean: float, + train_stddev: float, + split: str, + data_dir: str, + global_batch_size: int, + cache: bool = False, + repeat_final_dataset: bool = True, +) -> Iterator[Dict[str, spec.Tensor]]: shuffle = split in ['train', 'eval_train'] assert num_train_examples + num_validation_examples == 60000 if shuffle: @@ -42,12 +42,14 @@ def _build_mnist_dataset( else: tfds_split = 'test' ds = tfds.load( - 'mnist', split=tfds_split, shuffle_files=False, data_dir=data_dir) + 'mnist', split=tfds_split, shuffle_files=False, data_dir=data_dir + ) ds = ds.map( - lambda x: { - 'inputs': _normalize(x['image'], train_mean, train_stddev), - 'targets': x['label'], - }) + lambda x: { + 'inputs': _normalize(x['image'], train_mean, train_stddev), + 'targets': x['label'], + } + ) is_train = split == 'train' if cache: @@ -62,22 +64,23 @@ def _build_mnist_dataset( ds = ds.repeat() ds = map( - functools.partial( - 
data_utils.shard_and_maybe_pad_np, - global_batch_size=global_batch_size), - ds) + functools.partial( + data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size + ), + ds, + ) return iter(ds) class BaseMnistWorkload(spec.Workload): - @property def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" return 'accuracy' - def has_reached_validation_target(self, eval_result: Dict[str, - float]) -> bool: + def has_reached_validation_target( + self, eval_result: Dict[str, float] + ) -> bool: return eval_result['validation/accuracy'] > self.validation_target_value @property @@ -104,8 +107,9 @@ def num_eval_train_examples(self) -> int: # Round up from num_validation_examples (which is the default for # num_eval_train_examples) to the next multiple of eval_batch_size, so that # we don't have to extract the correctly sized subset of the training data. - rounded_up_multiple = math.ceil(self.num_validation_examples / - self.eval_batch_size) + rounded_up_multiple = math.ceil( + self.num_validation_examples / self.eval_batch_size + ) return rounded_up_multiple * self.eval_batch_size @property @@ -138,31 +142,33 @@ def eval_period_time_sec(self) -> int: @abc.abstractmethod def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" def _build_input_queue( - self, - data_rng: spec.RandomState, - split: str, - data_dir: str, - global_batch_size: int, - cache: Optional[bool] = None, - repeat_final_dataset: Optional[bool] = None, - num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]: + self, + data_rng: spec.RandomState, + split: str, + data_dir: str, + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None, + ) -> Iterator[Dict[str, spec.Tensor]]: del num_batches ds = _build_mnist_dataset( - data_rng=data_rng, - num_train_examples=self.num_train_examples, - num_validation_examples=self.num_validation_examples, - train_mean=self.train_mean, - train_stddev=self.train_stddev, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - cache=cache, - repeat_final_dataset=repeat_final_dataset) + data_rng=data_rng, + num_train_examples=self.num_train_examples, + num_validation_examples=self.num_validation_examples, + train_mean=self.train_mean, + train_stddev=self.train_stddev, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size, + cache=cache, + repeat_final_dataset=repeat_final_dataset, + ) return ds @property @@ -173,49 +179,52 @@ def step_hint(self) -> int: return 7813 def _eval_model( - self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: raise NotImplementedError - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: 
spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del global_step data_rng, model_rng = prng.split(rng, 2) if split not in self._eval_iters: self._eval_iters[split] = self._build_input_queue( - data_rng=data_rng, - split=split, - data_dir=data_dir, - global_batch_size=global_batch_size, - cache=True, - repeat_final_dataset=True) + data_rng=data_rng, + split=split, + data_dir=data_dir, + global_batch_size=global_batch_size, + cache=True, + repeat_final_dataset=True, + ) total_metrics = { - 'accuracy': 0., - 'loss': 0., + 'accuracy': 0.0, + 'loss': 0.0, } num_batches = int(math.ceil(num_examples / global_batch_size)) num_devices = max(torch.cuda.device_count(), jax.local_device_count()) for _ in range(num_batches): batch = next(self._eval_iters[split]) per_device_model_rngs = prng.split(model_rng, num_devices) - batch_metrics = self._eval_model(params, - batch, - model_state, - per_device_model_rngs) + batch_metrics = self._eval_model( + params, batch, model_state, per_device_model_rngs + ) total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() + k: v + batch_metrics[k] for k, v in total_metrics.items() } return self._normalize_eval_metrics(num_examples, total_metrics) diff --git a/algoperf/workloads/ogbg/input_pipeline.py b/algoperf/workloads/ogbg/input_pipeline.py index 3cb6f51de..79a4ddc4a 100644 --- a/algoperf/workloads/ogbg/input_pipeline.py +++ b/algoperf/workloads/ogbg/input_pipeline.py @@ -14,10 +14,10 @@ AVG_EDGES_PER_GRAPH = 56 TFDS_SPLIT_NAME = { - 'train': 'train', - 'eval_train': 'train', - 'validation': 'validation', - 'test': 'test', + 'train': 'train', + 'eval_train': 'train', + 'validation': 'validation', + 'test': 'test', } @@ -33,11 +33,12 @@ def _load_dataset(split, should_shuffle, data_rng, data_dir): read_config = tfds.ReadConfig(add_tfds_id=True, shuffle_seed=file_data_rng) dataset = tfds.load( - 'ogbg_molpcba:0.1.3', - split=TFDS_SPLIT_NAME[split], - shuffle_files=should_shuffle, - read_config=read_config, - data_dir=data_dir) + 'ogbg_molpcba:0.1.3', + split=TFDS_SPLIT_NAME[split], + shuffle_files=should_shuffle, + read_config=read_config, + data_dir=data_dir, + ) if should_shuffle: dataset = dataset.shuffle(seed=dataset_data_rng, buffer_size=2**15) @@ -62,16 +63,17 @@ def _to_jraph(example): receivers = edge_index[:, 1] return jraph.GraphsTuple( - n_node=num_nodes, - n_edge=np.array([len(edge_index) * 2]), - nodes=node_feat, - edges=np.concatenate([edge_feat, edge_feat]), - # Make the edges bidirectional - senders=np.concatenate([senders, receivers]), - receivers=np.concatenate([receivers, senders]), - # Keep the labels with the graph for batching. They will be removed - # in the processed batch. - globals=np.expand_dims(labels, axis=0)) + n_node=num_nodes, + n_edge=np.array([len(edge_index) * 2]), + nodes=node_feat, + edges=np.concatenate([edge_feat, edge_feat]), + # Make the edges bidirectional + senders=np.concatenate([senders, receivers]), + receivers=np.concatenate([receivers, senders]), + # Keep the labels with the graph for batching. They will be removed + # in the processed batch. 
+ globals=np.expand_dims(labels, axis=0), + ) def _get_weights_by_nan_and_padding(labels, padding_mask): @@ -123,10 +125,9 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None): max_n_graphs = per_device_batch_size jraph_iter = map(_to_jraph, dataset_iter) - batched_iter = jraph.dynamically_batch(jraph_iter, - max_n_nodes + 1, - max_n_edges, - max_n_graphs + 1) + batched_iter = jraph.dynamically_batch( + jraph_iter, max_n_nodes + 1, max_n_edges, max_n_graphs + 1 + ) count = 0 graphs_shards = [] @@ -141,7 +142,8 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None): graph = batched_graph._replace(globals={}) replaced_labels, weights = _get_weights_by_nan_and_padding( - labels, jraph.get_graph_padding_mask(graph)) + labels, jraph.get_graph_padding_mask(graph) + ) graphs_shards.append(graph) labels_shards.append(replaced_labels) @@ -156,9 +158,9 @@ def f(x): labels_shards = f(labels_shards) weights_shards = f(weights_shards) yield { - 'inputs': graphs_shards, - 'targets': labels_shards, - 'weights': weights_shards, + 'inputs': graphs_shards, + 'targets': labels_shards, + 'weights': weights_shards, } count = 0 @@ -170,5 +172,6 @@ def f(x): def get_dataset_iter(split, data_rng, data_dir, global_batch_size): shuffle = split in ['train', 'eval_train'] ds = _load_dataset( - split, should_shuffle=shuffle, data_rng=data_rng, data_dir=data_dir) + split, should_shuffle=shuffle, data_rng=data_rng, data_dir=data_dir + ) return _get_batch_iterator(iter(ds), global_batch_size) diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py index 55f83d905..7d41204c8 100644 --- a/algoperf/workloads/ogbg/metrics.py +++ b/algoperf/workloads/ogbg/metrics.py @@ -2,24 +2,23 @@ # https://github.com/google/flax/blob/main/examples/ogbg_molpcba/train.py from typing import Any -from clu import metrics import flax import jax import jax.numpy as jnp import numpy as np -from sklearn.metrics import average_precision_score import torch import torch.distributed as dist +from clu import metrics +from sklearn.metrics import average_precision_score from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() -def predictions_match_labels(*, - logits: jnp.ndarray, - labels: jnp.ndarray, - **kwargs) -> jnp.ndarray: +def predictions_match_labels( + *, logits: jnp.ndarray, labels: jnp.ndarray, **kwargs +) -> jnp.ndarray: """Returns a binary array indicating where predictions match the labels.""" del kwargs # Unused. preds = logits > 0 @@ -28,7 +27,8 @@ def predictions_match_labels(*, @flax.struct.dataclass class MeanAveragePrecision( - metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask'))): + metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask')) +): """Computes the mean average precision (mAP) over different tasks.""" def compute(self): @@ -37,6 +37,7 @@ def compute(self): labels = values['labels'] logits = values['logits'] mask = values['mask'] + sigmoid = jax.nn.sigmoid if USE_PYTORCH_DDP: # Sync labels, logits, and masks across devices. 
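# --- Illustrative sketch (synthetic data, hypothetical shapes): the host-side
# per-task mAP computation that compute() performs once logits, labels, and
# mask are NumPy arrays, using the NumPy sigmoid fallback in place of
# jax.nn.sigmoid.
import numpy as np
from sklearn.metrics import average_precision_score


def sigmoid_np(x):
  return 1 / (1 + np.exp(-x))


rng = np.random.default_rng(0)
logits = rng.normal(size=(8, 3))          # [num_examples, num_tasks]
labels = rng.integers(0, 2, size=(8, 3))  # binary label per task
mask = np.ones((8, 3), dtype=bool)        # which labels are valid

probs = sigmoid_np(logits)
average_precisions = np.full(labels.shape[1], np.nan)
for task in range(labels.shape[1]):
  # Score a task only if it has both positive and negative labels.
  if (labels[:, task] == 0).any() and (labels[:, task] == 1).any():
    is_labeled = mask[:, task]
    average_precisions[task] = average_precision_score(
      labels[is_labeled, task], probs[is_labeled, task]
    )

# If every task is unscored, return NaN directly to avoid a RuntimeWarning.
mean_ap = (
  np.nan
  if np.isnan(average_precisions).all()
  else np.nanmean(average_precisions)
)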
@@ -49,9 +50,14 @@ def compute(self): all_values[idx] = torch.cat(all_tensors).cpu().numpy() labels, logits, mask = all_values + def sigmoid_np(x): + return 1 / (1 + np.exp(-x)) + + sigmoid = sigmoid_np + mask = mask.astype(bool) - probs = jax.nn.sigmoid(logits) + probs = sigmoid(logits) num_tasks = labels.shape[1] average_precisions = np.full(num_tasks, np.nan) @@ -62,7 +68,8 @@ def compute(self): if np.sum(labels[:, task] == 0) > 0 and np.sum(labels[:, task] == 1) > 0: is_labeled = mask[:, task] average_precisions[task] = average_precision_score( - labels[is_labeled, task], probs[is_labeled, task]) + labels[is_labeled, task], probs[is_labeled, task] + ) # When all APs are NaNs, return NaN. This avoids raising a RuntimeWarning. if np.isnan(average_precisions).all(): diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index 0e66d2ab8..db1ca416c 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -1,21 +1,24 @@ # Forked from the init2winit implementation here # https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. -from typing import Optional, Tuple +from typing import Tuple -from flax import linen as nn import jax.numpy as jnp import jraph +from flax import linen as nn +from algoperf.jax_utils import Dropout + +DROPOUT_RATE = 0.1 -def _make_embed(latent_dim, name): +def _make_embed(latent_dim, name): def make_fn(inputs): return nn.Dense(features=latent_dim, name=name)(inputs) return make_fn -def _make_mlp(hidden_dims, dropout, activation_fn): +def _make_mlp(hidden_dims, activation_fn, train, dropout_rate=DROPOUT_RATE): """Creates a MLP with specified dimensions.""" @jraph.concatenated_args @@ -25,7 +28,9 @@ def make_fn(inputs): x = nn.Dense(features=dim)(x) x = nn.LayerNorm()(x) x = activation_fn(x) - x = dropout(x) + x = Dropout(rate=dropout_rate, deterministic=not train)( + x, rate=dropout_rate + ) return x return make_fn @@ -36,28 +41,23 @@ class GNN(nn.Module): The model assumes the input data is a jraph.GraphsTuple without global variables. The final prediction will be encoded in the globals. """ + num_outputs: int latent_dim: int = 256 hidden_dims: Tuple[int] = (256,) - # If None, defaults to 0.1. 
- dropout_rate: Optional[float] = 0.1 num_message_passing_steps: int = 5 activation_fn_name: str = 'relu' @nn.compact - def __call__(self, graph, train): - if self.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = self.dropout_rate - dropout = nn.Dropout(rate=dropout_rate, deterministic=not train) - + def __call__(self, graph, train, dropout_rate=DROPOUT_RATE): graph = graph._replace( - globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])) + globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]) + ) embedder = jraph.GraphMapFeatures( - embed_node_fn=_make_embed(self.latent_dim, name='node_embedding'), - embed_edge_fn=_make_embed(self.latent_dim, name='edge_embedding')) + embed_node_fn=_make_embed(self.latent_dim, name='node_embedding'), + embed_edge_fn=_make_embed(self.latent_dim, name='edge_embedding'), + ) graph = embedder(graph) if self.activation_fn_name == 'relu': @@ -68,16 +68,30 @@ def __call__(self, graph, train): activation_fn = nn.silu else: raise ValueError( - f'Invalid activation function name: {self.activation_fn_name}') + f'Invalid activation function name: {self.activation_fn_name}' + ) for _ in range(self.num_message_passing_steps): net = jraph.GraphNetwork( - update_edge_fn=_make_mlp( - self.hidden_dims, dropout=dropout, activation_fn=activation_fn), - update_node_fn=_make_mlp( - self.hidden_dims, dropout=dropout, activation_fn=activation_fn), - update_global_fn=_make_mlp( - self.hidden_dims, dropout=dropout, activation_fn=activation_fn)) + update_edge_fn=_make_mlp( + self.hidden_dims, + activation_fn=activation_fn, + train=train, + dropout_rate=dropout_rate, + ), + update_node_fn=_make_mlp( + self.hidden_dims, + activation_fn=activation_fn, + train=train, + dropout_rate=dropout_rate, + ), + update_global_fn=_make_mlp( + self.hidden_dims, + activation_fn=activation_fn, + train=train, + dropout_rate=dropout_rate, + ), + ) graph = net(graph) diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index e895d15a7..8471fcdcc 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -1,47 +1,41 @@ """OGBG workload implemented in Jax.""" + import functools -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Tuple -from flax import jax_utils import jax import jax.numpy as jnp import jraph import optax +from flax import jax_utils -from algoperf import param_utils -from algoperf import spec +from algoperf import param_utils, spec from algoperf.workloads.ogbg import metrics from algoperf.workloads.ogbg.ogbg_jax import models from algoperf.workloads.ogbg.workload import BaseOgbgWorkload class OgbgWorkload(BaseOgbgWorkload): - - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is unused.""" - del aux_dropout_rate - rng, params_rng, dropout_rng = jax.random.split(rng, 3) + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: + rng, params_rng = jax.random.split(rng, 2) self._model = models.GNN( - self._num_outputs, - dropout_rate=dropout_rate, - activation_fn_name=self.activation_fn_name, - hidden_dims=self.hidden_dims, - latent_dim=self.latent_dim, - num_message_passing_steps=self.num_message_passing_steps) + self._num_outputs, + activation_fn_name=self.activation_fn_name, + hidden_dims=self.hidden_dims, + latent_dim=self.latent_dim, + 
num_message_passing_steps=self.num_message_passing_steps, + ) init_fn = jax.jit(functools.partial(self._model.init, train=False)) fake_batch = jraph.GraphsTuple( - n_node=jnp.asarray([1]), - n_edge=jnp.asarray([1]), - nodes=jnp.ones((1, 9)), - edges=jnp.ones((1, 3)), - globals=jnp.zeros((1, self._num_outputs)), - senders=jnp.asarray([0]), - receivers=jnp.asarray([0])) - params = init_fn({'params': params_rng, 'dropout': dropout_rng}, fake_batch) + n_node=jnp.asarray([1]), + n_edge=jnp.asarray([1]), + nodes=jnp.ones((1, 9)), + edges=jnp.ones((1, 3)), + globals=jnp.zeros((1, self._num_outputs)), + senders=jnp.asarray([0]), + receivers=jnp.asarray([0]), + ) + params = init_fn({'params': params_rng}, fake_batch) params = params['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -51,37 +45,45 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_17' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" del update_batch_norm # No BN in the GNN model. if model_state is not None: raise ValueError( - f'Expected model_state to be None, received {model_state}.') + f'Expected model_state to be None, received {model_state}.' + ) train = mode == spec.ForwardPassMode.TRAIN - logits = self._model.apply({'params': params}, - augmented_and_preprocessed_input_batch['inputs'], - rngs={'dropout': rng}, - train=train) + logits = self._model.apply( + {'params': params}, + augmented_and_preprocessed_input_batch['inputs'], + rngs={'dropout': rng}, + train=train, + dropout_rate=dropout_rate, + ) return logits, None def _binary_cross_entropy_with_mask( - self, - labels: jnp.ndarray, - logits: jnp.ndarray, - mask: jnp.ndarray, - label_smoothing: float = 0.0) -> jnp.ndarray: + self, + labels: jnp.ndarray, + logits: jnp.ndarray, + mask: jnp.ndarray, + label_smoothing: float = 0.0, + ) -> jnp.ndarray: """Binary cross entropy loss for logits, with masked elements.""" if not (logits.shape == labels.shape == mask.shape): # pylint: disable=superfluous-parens raise ValueError( - f'Shape mismatch between logits ({logits.shape}), targets ' - f'({labels.shape}), and weights ({mask.shape}).') + f'Shape mismatch between logits ({logits.shape}), targets ' + f'({labels.shape}), and weights ({mask.shape}).' + ) if len(logits.shape) != 2: raise ValueError(f'Rank of logits ({logits.shape}) must be 2.') @@ -97,26 +99,31 @@ def _binary_cross_entropy_with_mask( positive_logits = logits >= 0 relu_logits = jnp.where(positive_logits, logits, 0) abs_logits = jnp.where(positive_logits, logits, -logits) - losses = relu_logits - (logits * smoothed_labels) + ( - jnp.log(1 + jnp.exp(-abs_logits))) - return jnp.where(mask, losses, 0.) 
+ losses = ( + relu_logits + - (logits * smoothed_labels) + + (jnp.log(1 + jnp.exp(-abs_logits))) + ) + return jnp.where(mask, losses, 0.0) def _eval_metric(self, labels, logits, masks): loss = self.loss_fn(labels, logits, masks) return metrics.EvalMetrics.single_from_model_output( - loss=loss['per_example'], logits=logits, labels=labels, mask=masks) + loss=loss['per_example'], logits=logits, labels=labels, mask=masks + ) @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.pmap, + axis_name='batch', + in_axes=(None, 0, 0, 0, None), + static_broadcasted_argnums=(0,), + ) def _eval_batch(self, params, batch, model_state, rng): return super()._eval_batch(params, batch, model_state, rng) def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" del num_examples total_metrics = total_metrics.reduce() @@ -124,7 +131,6 @@ def _normalize_eval_metrics( class OgbgGeluWorkload(OgbgWorkload): - @property def activation_fn_name(self) -> str: """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" @@ -140,7 +146,6 @@ def test_target_value(self) -> float: class OgbgSiluWorkload(OgbgWorkload): - @property def activation_fn_name(self) -> str: """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" @@ -156,7 +161,6 @@ def test_target_value(self) -> float: class OgbgModelSizeWorkload(OgbgWorkload): - @property def hidden_dims(self) -> Tuple[int]: return (256, 256) diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models.py b/algoperf/workloads/ogbg/ogbg_pytorch/models.py index fe9b29bc1..a69bc6ee1 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/models.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/models.py @@ -4,22 +4,26 @@ from typing import Callable, Optional, Tuple import jax.tree_util as tree -from jraph import GraphsTuple import torch +from jraph import GraphsTuple from torch import nn from algoperf import init_utils +from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout + +DROPOUT_RATE = 0.1 -def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn): +def _make_mlp(in_dim, hidden_dims, activation_fn): """Creates a MLP with specified dimensions.""" - layers = nn.Sequential() + layers = SequentialWithDropout() for i, dim in enumerate(hidden_dims): - layers.add_module(f'dense_{i}', - nn.Linear(in_features=in_dim, out_features=dim)) + layers.add_module( + f'dense_{i}', nn.Linear(in_features=in_dim, out_features=dim) + ) layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6)) layers.add_module(f'activation_fn_{i}', activation_fn()) - layers.add_module(f'dropout_{i}', nn.Dropout(dropout_rate)) + layers.add_module(f'dropout_{i}', CustomDropout()) in_dim = dim return layers @@ -31,20 +35,19 @@ class GNN(nn.Module): variables. The final prediction will be encoded in the globals. 
""" - def __init__(self, - num_outputs: int = 128, - dropout_rate: Optional[float] = 0.1, - activation_fn_name: str = 'relu', - latent_dim: int = 256, - hidden_dims: Tuple[int] = (256,), - num_message_passing_steps: int = 5) -> None: + def __init__( + self, + num_outputs: int = 128, + activation_fn_name: str = 'relu', + latent_dim: int = 256, + hidden_dims: Tuple[int] = (256,), + num_message_passing_steps: int = 5, + ) -> None: super().__init__() self.latent_dim = latent_dim self.hidden_dims = hidden_dims self.num_message_passing_steps = num_message_passing_steps self.num_outputs = num_outputs - if dropout_rate is None: - dropout_rate = 0.1 # in_features are specifically chosen for the ogbg workload. self.node_embedder = nn.Linear(in_features=9, out_features=self.latent_dim) self.edge_embedder = nn.Linear(in_features=3, out_features=self.latent_dim) @@ -57,7 +60,8 @@ def __init__(self, activation_fn = nn.SiLU else: raise ValueError( - f'Invalid activation function name: {self.activation_fn_name}') + f'Invalid activation function name: {self.activation_fn_name}' + ) graph_network_layers = [] for st in range(self.num_message_passing_steps): @@ -65,8 +69,9 @@ def __init__(self, # specifically update_edge_fn update_node_fn and update_global_fn. if st == 0: in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs - in_dim_node_fn = self.latent_dim + self.hidden_dims[ - -1] * 2 + self.num_outputs + in_dim_node_fn = ( + self.latent_dim + self.hidden_dims[-1] * 2 + self.num_outputs + ) last_in_dim = self.hidden_dims[-1] * 2 + self.num_outputs else: in_dim_edge_fn = self.hidden_dims[-1] * 4 @@ -74,36 +79,40 @@ def __init__(self, last_in_dim = self.hidden_dims[-1] * 3 graph_network_layers.append( - GraphNetwork( - update_edge_fn=_make_mlp(in_dim_edge_fn, - self.hidden_dims, - dropout_rate, - activation_fn), - update_node_fn=_make_mlp(in_dim_node_fn, - self.hidden_dims, - dropout_rate, - activation_fn), - update_global_fn=_make_mlp(last_in_dim, - self.hidden_dims, - dropout_rate, - activation_fn))) - self.graph_network = nn.Sequential(*graph_network_layers) + GraphNetwork( + update_edge_fn=_make_mlp( + in_dim_edge_fn, self.hidden_dims, activation_fn + ), + update_node_fn=_make_mlp( + in_dim_node_fn, self.hidden_dims, activation_fn + ), + update_global_fn=_make_mlp( + last_in_dim, self.hidden_dims, activation_fn + ), + ) + ) + self.graph_network = SequentialWithDropout(*graph_network_layers) self.decoder = nn.Linear( - in_features=self.hidden_dims[-1], out_features=self.num_outputs) + in_features=self.hidden_dims[-1], out_features=self.num_outputs + ) for m in self.modules(): if isinstance(m, nn.Linear): init_utils.pytorch_default_init(m) - def forward(self, graph: GraphsTuple) -> torch.Tensor: + def forward( + self, graph: GraphsTuple, dropout_rate: float = DROPOUT_RATE + ) -> torch.Tensor: graph = graph._replace( - globals=torch.zeros([graph.n_node.shape[0], self.num_outputs], - device=graph.n_node.device)) + globals=torch.zeros( + [graph.n_node.shape[0], self.num_outputs], device=graph.n_node.device + ) + ) graph = graph._replace(nodes=self.node_embedder(graph.nodes)) graph = graph._replace(edges=self.edge_embedder(graph.edges)) - graph = self.graph_network(graph) + graph = self.graph_network(graph, dropout_rate) # Map globals to represent the final result graph = graph._replace(globals=self.decoder(graph.globals)) @@ -137,16 +146,19 @@ class GraphNetwork(nn.Module): A method that applies the configured GraphNetwork. 
""" - def __init__(self, - update_edge_fn: Optional[Callable] = None, - update_node_fn: Optional[Callable] = None, - update_global_fn: Optional[Callable] = None) -> None: + def __init__( + self, + update_edge_fn: Optional[Callable] = None, + update_node_fn: Optional[Callable] = None, + update_global_fn: Optional[Callable] = None, + ) -> None: super().__init__() self.update_edge_fn = update_edge_fn self.update_node_fn = update_node_fn self.update_global_fn = update_global_fn + self._supports_custom_dropout = True # supports SequentialWithDropout - def forward(self, graph: GraphsTuple) -> GraphsTuple: + def forward(self, graph: GraphsTuple, dropout_rate: float) -> GraphsTuple: """Applies a configured GraphNetwork to a graph. This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261 There is one difference. For the nodes update the class aggregates over the @@ -159,42 +171,49 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: GraphNets, for more information please see the paper. Args: graph: a `GraphsTuple` containing the graph. + dropout_rate: dropout probability value. Returns: Updated `GraphsTuple`. """ nodes, edges, receivers, senders, globals_, n_node, n_edge = graph sum_n_node = tree.tree_leaves(nodes)[0].shape[0] if not tree.tree_all( - tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)): + tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes) + ): raise ValueError( - 'All node arrays in nest must contain the same number of nodes.') + 'All node arrays in nest must contain the same number of nodes.' + ) sent_attributes = tree.tree_map(lambda n: n[senders], nodes) received_attributes = tree.tree_map(lambda n: n[receivers], nodes) # Here we scatter the global features to the corresponding edges, # giving us tensors of shape [num_edges, global_feat]. global_edge_attributes = tree.tree_map( - lambda g: torch.repeat_interleave(g, n_edge, dim=0), globals_) + lambda g: torch.repeat_interleave(g, n_edge, dim=0), globals_ + ) if self.update_edge_fn: edge_fn_inputs = torch.cat( - [edges, sent_attributes, received_attributes, global_edge_attributes], - dim=-1) - edges = self.update_edge_fn(edge_fn_inputs) + [edges, sent_attributes, received_attributes, global_edge_attributes], + dim=-1, + ) + edges = self.update_edge_fn(edge_fn_inputs, dropout_rate) if self.update_node_fn: sent_attributes = tree.tree_map( - lambda e: scatter_sum(e, senders, dim=0, dim_size=sum_n_node), edges) + lambda e: scatter_sum(e, senders, dim=0, dim_size=sum_n_node), edges + ) received_attributes = tree.tree_map( - lambda e: scatter_sum(e, receivers, dim=0, dim_size=sum_n_node), - edges) + lambda e: scatter_sum(e, receivers, dim=0, dim_size=sum_n_node), edges + ) # Here we scatter the global features to the corresponding nodes, # giving us tensors of shape [num_nodes, global_feat]. global_attributes = tree.tree_map( - lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_) + lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_ + ) node_fn_inputs = torch.cat( - [nodes, sent_attributes, received_attributes, global_attributes], - dim=-1) - nodes = self.update_node_fn(node_fn_inputs) + [nodes, sent_attributes, received_attributes, global_attributes], dim=-1 + ) + nodes = self.update_node_fn(node_fn_inputs, dropout_rate) if self.update_global_fn: n_graph = n_node.shape[0] @@ -207,31 +226,37 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: edge_gr_idx = torch.repeat_interleave(graph_idx, n_edge, dim=0) # We use the aggregation function to pool the nodes/edges per graph. 
node_attributes = tree.tree_map( - lambda n: scatter_sum(n, node_gr_idx, dim=0, dim_size=n_graph), nodes) + lambda n: scatter_sum(n, node_gr_idx, dim=0, dim_size=n_graph), nodes + ) edge_attributes = tree.tree_map( - lambda e: scatter_sum(e, edge_gr_idx, dim=0, dim_size=n_graph), edges) + lambda e: scatter_sum(e, edge_gr_idx, dim=0, dim_size=n_graph), edges + ) # These pooled nodes are the inputs to the global update fn. - global_fn_inputs = torch.cat([node_attributes, edge_attributes, globals_], - dim=-1) - globals_ = self.update_global_fn(global_fn_inputs) + global_fn_inputs = torch.cat( + [node_attributes, edge_attributes, globals_], dim=-1 + ) + globals_ = self.update_global_fn(global_fn_inputs, dropout_rate) return GraphsTuple( - nodes=nodes, - edges=edges, - receivers=receivers, - senders=senders, - globals=globals_, - n_node=n_node, - n_edge=n_edge) + nodes=nodes, + edges=edges, + receivers=receivers, + senders=senders, + globals=globals_, + n_node=n_node, + n_edge=n_edge, + ) # Forked from # github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py. -def scatter_sum(src: torch.Tensor, - index: torch.Tensor, - dim: int = -1, - out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None) -> torch.Tensor: +def scatter_sum( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, +) -> torch.Tensor: r""" | .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py index 45295ac7f..f72ff5141 100644 --- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py @@ -1,17 +1,17 @@ """OGBG workload implemented in PyTorch.""" + import contextlib from typing import Any, Callable, Dict, Optional, Tuple import jax -from jraph import GraphsTuple import torch import torch.distributed as dist +from jraph import GraphsTuple from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils -from algoperf import pytorch_utils -from algoperf import spec +from algoperf import param_utils, pytorch_utils, spec from algoperf.workloads.ogbg import metrics +from algoperf.workloads.ogbg.ogbg_pytorch import models from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN from algoperf.workloads.ogbg.workload import BaseOgbgWorkload @@ -22,9 +22,11 @@ def _pytorch_map(inputs: Any) -> Any: if USE_PYTORCH_DDP: return jax.tree.map(lambda a: torch.as_tensor(a, device=DEVICE), inputs) return jax.tree.map( - lambda a: torch.as_tensor(a, device=DEVICE).view(-1, a.shape[-1]) - if len(a.shape) == 3 else torch.as_tensor(a, device=DEVICE).view(-1), - inputs) + lambda a: torch.as_tensor(a, device=DEVICE).view(-1, a.shape[-1]) + if len(a.shape) == 3 + else torch.as_tensor(a, device=DEVICE).view(-1), + inputs, + ) def _shard(inputs: Any) -> Any: @@ -35,44 +37,47 @@ def _shard(inputs: Any) -> Any: def _graph_map(function: Callable, graph: GraphsTuple) -> GraphsTuple: return GraphsTuple( - nodes=function(graph.nodes), - edges=function(graph.edges), - receivers=function(graph.receivers), - senders=function(graph.senders), - globals=function(graph.globals), - n_node=function(graph.n_node), - n_edge=function(graph.n_edge)) + nodes=function(graph.nodes), + edges=function(graph.edges), + receivers=function(graph.receivers), + senders=function(graph.senders), + globals=function(graph.globals), + n_node=function(graph.n_node), + 
n_edge=function(graph.n_edge), + ) class OgbgWorkload(BaseOgbgWorkload): - # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of valid examples in batch, 'per_example': 1-d array of per-example losses} (not synced across devices). """ - loss_dict = super().loss_fn(label_batch, - logits_batch, - mask_batch, - label_smoothing) + loss_dict = super().loss_fn( + label_batch, logits_batch, mask_batch, label_smoothing + ) loss_dict['n_valid_examples'] = torch.as_tensor( - loss_dict['n_valid_examples'], device=DEVICE) + loss_dict['n_valid_examples'], device=DEVICE + ) return loss_dict - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int): + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + ): # TODO: Check where the + 1 comes from. per_device_batch_size = int(global_batch_size / N_GPUS) + 1 @@ -80,10 +85,9 @@ def _build_input_queue(self, # avoid creating too many threads. if RANK == 0: data_rng = data_rng.astype('uint32') - dataset_iter = super()._build_input_queue(data_rng, - split, - data_dir, - global_batch_size) + dataset_iter = super()._build_input_queue( + data_rng, split, data_dir, global_batch_size + ) while True: if RANK == 0: @@ -91,14 +95,16 @@ def _build_input_queue(self, graph = _graph_map(_pytorch_map, batch['inputs']) targets = torch.as_tensor(batch['targets'], device=DEVICE) weights = torch.as_tensor( - batch['weights'], dtype=torch.bool, device=DEVICE) + batch['weights'], dtype=torch.bool, device=DEVICE + ) # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: dist.broadcast_object_list([graph], src=0, device=DEVICE) # During eval, the batch size of the remainder might be different. if split != 'train': per_device_batch_size = torch.tensor( - len(targets[0]), dtype=torch.int32, device=DEVICE) + len(targets[0]), dtype=torch.int32, device=DEVICE + ) dist.broadcast(per_device_batch_size, src=0) dist.broadcast(targets, src=0) targets = targets[0] @@ -113,44 +119,40 @@ def _build_input_queue(self, graph = graph[0] # During eval, the batch size of the remainder might be different. 
if split != 'train': - per_device_batch_size = torch.empty((1,), - dtype=torch.int32, - device=DEVICE) + per_device_batch_size = torch.empty( + (1,), dtype=torch.int32, device=DEVICE + ) dist.broadcast(per_device_batch_size, src=0) targets = torch.empty( - (N_GPUS, per_device_batch_size, self._num_outputs), device=DEVICE) + (N_GPUS, per_device_batch_size, self._num_outputs), device=DEVICE + ) dist.broadcast(targets, src=0) targets = targets[RANK] weights = torch.empty( - (N_GPUS, per_device_batch_size, self._num_outputs), - dtype=torch.bool, - device=DEVICE) + (N_GPUS, per_device_batch_size, self._num_outputs), + dtype=torch.bool, + device=DEVICE, + ) dist.broadcast(weights, src=0) weights = weights[RANK] batch = { - 'inputs': _graph_map(_shard, graph), - 'targets': targets, - 'weights': weights, + 'inputs': _graph_map(_shard, graph), + 'targets': targets, + 'weights': weights, } yield batch - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is unused.""" - del aux_dropout_rate + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) model = GNN( - num_outputs=self._num_outputs, - dropout_rate=dropout_rate, - hidden_dims=self.hidden_dims, - latent_dim=self.latent_dim, - num_message_passing_steps=self.num_message_passing_steps, - activation_fn_name=self.activation_fn_name) + num_outputs=self._num_outputs, + hidden_dims=self.hidden_dims, + latent_dim=self.latent_dim, + num_message_passing_steps=self.num_message_passing_steps, + activation_fn_name=self.activation_fn_name, + ) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -165,19 +167,22 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['decoder.weight', 'decoder.bias'] def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: """Get predicted logits from the network for input graphs.""" del rng del update_batch_norm # No BN in the GNN model. if model_state is not None: raise ValueError( - f'Expected model_state to be None, received {model_state}.') + f'Expected model_state to be None, received {model_state}.' 
+ ) model = params if mode == spec.ForwardPassMode.TRAIN: @@ -186,26 +191,31 @@ def model_fn( model.eval() contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): - logits = model(augmented_and_preprocessed_input_batch['inputs']) + logits = model( + augmented_and_preprocessed_input_batch['inputs'], + dropout_rate=dropout_rate, + ) return logits, None def _binary_cross_entropy_with_mask( - self, - labels: torch.Tensor, - logits: torch.Tensor, - mask: torch.Tensor, - label_smoothing: float = 0.0) -> torch.Tensor: + self, + labels: torch.Tensor, + logits: torch.Tensor, + mask: torch.Tensor, + label_smoothing: float = 0.0, + ) -> torch.Tensor: """Binary cross entropy loss for logits, with masked elements.""" if not (logits.shape == labels.shape == mask.shape): # pylint: disable=superfluous-parens raise ValueError( - f'Shape mismatch between logits ({logits.shape}), targets ' - f'({labels.shape}), and weights ({mask.shape}).') + f'Shape mismatch between logits ({logits.shape}), targets ' + f'({labels.shape}), and weights ({mask.shape}).' + ) if len(logits.shape) != 2: raise ValueError(f'Rank of logits ({logits.shape}) must be 2.') @@ -215,36 +225,40 @@ def _binary_cross_entropy_with_mask( # Apply label_smoothing. num_classes = labels.shape[-1] - smoothed_labels = ((1.0 - label_smoothing) * labels + - label_smoothing / num_classes) + smoothed_labels = ( + 1.0 - label_smoothing + ) * labels + label_smoothing / num_classes # Numerically stable implementation of BCE loss. # This mimics TensorFlow's tf.nn.sigmoid_cross_entropy_with_logits(). positive_logits = logits >= 0 relu_logits = torch.where(positive_logits, logits, 0) abs_logits = torch.where(positive_logits, logits, -logits) - losses = relu_logits - (logits * smoothed_labels) + ( - torch.log(1 + torch.exp(-abs_logits))) - return torch.where(mask.to(torch.bool), losses, 0.) + losses = ( + relu_logits + - (logits * smoothed_labels) + + (torch.log(1 + torch.exp(-abs_logits))) + ) + return torch.where(mask.to(torch.bool), losses, 0.0) def _eval_metric(self, labels, logits, masks): loss = self.loss_fn(labels, logits, masks) return metrics.EvalMetrics.single_from_model_output( - loss=loss['per_example'].cpu().numpy(), - logits=logits.cpu().numpy(), - labels=labels.cpu().numpy(), - mask=masks.cpu().numpy()) + loss=loss['per_example'].cpu().numpy(), + logits=logits.cpu().numpy(), + labels=labels.cpu().numpy(), + mask=masks.cpu().numpy(), + ) def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" del num_examples return {k: float(v) for k, v in total_metrics.compute().items()} class OgbgGeluWorkload(OgbgWorkload): - @property def activation_fn_name(self) -> str: """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" @@ -260,7 +274,6 @@ def test_target_value(self) -> float: class OgbgSiluWorkload(OgbgWorkload): - @property def activation_fn_name(self) -> str: """Name of the activation function to use. 
One of 'relu', 'gelu', 'silu'.""" @@ -276,7 +289,6 @@ def test_target_value(self) -> float: class OgbgModelSizeWorkload(OgbgWorkload): - @property def hidden_dims(self) -> Tuple[int]: return (256, 256) diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 971e7f0f6..8717e46d6 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -9,12 +9,10 @@ from algoperf import random_utils as prng from algoperf import spec -from algoperf.workloads.ogbg import input_pipeline -from algoperf.workloads.ogbg import metrics +from algoperf.workloads.ogbg import input_pipeline, metrics class BaseOgbgWorkload(spec.Workload): - _num_outputs: int = 128 @property @@ -40,8 +38,10 @@ def num_message_passing_steps(self) -> int: return 5 def has_reached_validation_target(self, eval_result: float) -> bool: - return eval_result[ - 'validation/mean_average_precision'] > self.validation_target_value + return ( + eval_result['validation/mean_average_precision'] + > self.validation_target_value + ) @property def validation_target_value(self) -> float: @@ -94,15 +94,16 @@ def max_allowed_runtime_sec(self) -> int: def eval_period_time_sec(self) -> int: return 4 * 60 - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int): - dataset_iter = input_pipeline.get_dataset_iter(split, - data_rng, - data_dir, - global_batch_size) + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + ): + dataset_iter = input_pipeline.get_dataset_iter( + split, data_rng, data_dir, global_batch_size + ) if split != 'train': # Note that this stores the entire val dataset in memory. dataset_iter = itertools.cycle(dataset_iter) @@ -111,11 +112,12 @@ def _build_input_queue(self, # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -123,19 +125,20 @@ def loss_fn( (not synced across devices). 
""" per_example_losses = self._binary_cross_entropy_with_mask( - labels=label_batch, - logits=logits_batch, - mask=mask_batch, - label_smoothing=label_smoothing) + labels=label_batch, + logits=logits_batch, + mask=mask_batch, + label_smoothing=label_smoothing, + ) if mask_batch is not None: n_valid_examples = mask_batch.sum() else: n_valid_examples = len(per_example_losses) summed_loss = per_example_losses.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } @property @@ -145,39 +148,45 @@ def step_hint(self) -> int: @abc.abstractmethod def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" - def _eval_batch(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState) -> metrics.EvalMetrics: + def _eval_batch( + self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + ) -> metrics.EvalMetrics: logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False, + ) return self._eval_metric(batch['targets'], logits, batch['weights']) - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del global_step data_rng, model_rng = prng.split(rng, 2) if split not in self._eval_iters: self._eval_iters[split] = self._build_input_queue( - data_rng, split, data_dir, global_batch_size=global_batch_size) + data_rng, split, data_dir, global_batch_size=global_batch_size + ) total_metrics = None num_eval_steps = int(math.ceil(float(num_examples) / global_batch_size)) @@ -186,8 +195,10 @@ def _eval_model_on_split(self, batch = next(self._eval_iters[split]) batch_metrics = self._eval_batch(params, batch, model_state, model_rng) total_metrics = ( - batch_metrics - if total_metrics is None else total_metrics.merge(batch_metrics)) + batch_metrics + if total_metrics is None + else total_metrics.merge(batch_metrics) + ) if total_metrics is None: return {} return self._normalize_eval_metrics(num_examples, total_metrics) diff --git a/algoperf/workloads/utils.py b/algoperf/workloads/utils.py index 7719f91fb..920c3cf46 100644 --- a/algoperf/workloads/utils.py +++ b/algoperf/workloads/utils.py @@ -5,10 +5,12 @@ def print_jax_model_summary(model, fake_inputs): """Prints a summary of the jax module.""" tabulate_fn = nn.tabulate( - model, - jax.random.PRNGKey(0), - console_kwargs={ - 'force_terminal': False, 'force_jupyter': False, 'width': 240 - }, + model, + jax.random.PRNGKey(0), + console_kwargs={ + 'force_terminal': False, + 'force_jupyter': False, + 'width': 240, + }, ) 
print(tabulate_fn(fake_inputs, train=False)) diff --git a/algoperf/workloads/wmt/bleu.py b/algoperf/workloads/wmt/bleu.py index ad314a7d3..6e29b1b83 100644 --- a/algoperf/workloads/wmt/bleu.py +++ b/algoperf/workloads/wmt/bleu.py @@ -5,19 +5,17 @@ https://github.com/mjpost/sacrebleu/blob/v1.3.1/sacrebleu.py. """ -from collections import Counter -from collections import namedtuple -from itertools import zip_longest -import logging import math import re import sys -from typing import List, Sequence import unicodedata +from collections import Counter, namedtuple +from itertools import zip_longest +from typing import List, Sequence -from absl import logging import torch import torch.distributed as dist +from absl import logging from algoperf.pytorch_utils import pytorch_setup @@ -30,11 +28,11 @@ def my_log(num): """ - Floors the log function + Floors the log function - :param num: the number - :return: log(num) floored to a very low number - """ + :param num: the number + :return: log(num) floored to a very low number + """ if num == 0.0: return -9999999999 @@ -43,12 +41,12 @@ def my_log(num): def tokenize_13a(line): """ - Tokenizes an input line using a relatively minimal tokenization that is - however equivalent to mteval-v13a, used by WMT. + Tokenizes an input line using a relatively minimal tokenization that is + however equivalent to mteval-v13a, used by WMT. - :param line: a segment to tokenize - :return: the tokenized line - """ + :param line: a segment to tokenize + :return: the tokenized line + """ norm = line @@ -62,14 +60,17 @@ def tokenize_13a(line): norm = norm.replace('>', '>') # language-dependent part (assuming Western languages): - norm = " {} ".format(norm) + norm = ' {} '.format(norm) norm = re.sub(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', ' \\1 ', norm) - norm = re.sub(r'([^0-9])([\.,])', '\\1 \\2 ', - norm) # tokenize period and comma unless preceded by a digit - norm = re.sub(r'([\.,])([^0-9])', ' \\1 \\2', - norm) # tokenize period and comma unless followed by a digit - norm = re.sub(r'([0-9])(-)', '\\1 \\2 ', - norm) # tokenize dash when preceded by a digit + norm = re.sub( + r'([^0-9])([\.,])', '\\1 \\2 ', norm + ) # tokenize period and comma unless preceded by a digit + norm = re.sub( + r'([\.,])([^0-9])', ' \\1 \\2', norm + ) # tokenize period and comma unless followed by a digit + norm = re.sub( + r'([0-9])(-)', '\\1 \\2 ', norm + ) # tokenize dash when preceded by a digit norm = re.sub(r'\s+', ' ', norm) # one space only between words norm = re.sub(r'^\s+', '', norm) # no leading space norm = re.sub(r'\s+$', '', norm) # no trailing space @@ -80,14 +81,15 @@ def tokenize_13a(line): class UnicodeRegex: """Ad-hoc hack to recognize all punctuation and symbols. - without depending on https://pypi.python.org/pypi/regex/.""" + without depending on https://pypi.python.org/pypi/regex/.""" @staticmethod def _property_chars(prefix): return ''.join( - chr(x) - for x in range(sys.maxunicode) - if unicodedata.category(chr(x)).startswith(prefix)) + chr(x) + for x in range(sys.maxunicode) + if unicodedata.category(chr(x)).startswith(prefix) + ) punctuation = _property_chars('P') nondigit_punct_re = re.compile(r'([^\d])([' + punctuation + r'])') @@ -98,27 +100,27 @@ def _property_chars(prefix): def tokenize_v14_international(string): r"""Tokenize a string following the official BLEU implementation. 
- See - https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 - In our case, the input string is expected to be just one line - and no HTML entities de-escaping is needed. - So we just tokenize on punctuation and symbols, - except when a punctuation is preceded and followed by a digit - (e.g. a comma/dot as a thousand/decimal separator). - - Note that a number (e.g., a year) followed by a dot at the end of sentence - is NOT tokenized, - i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g` - does not match this case (unless we add a space after each sentence). - However, this error is already in the original mteval-v14.pl - and we want to be consistent with it. - The error is not present in the non-international version, - which uses, - `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python). - - :param string: the input string - :return: a list of tokens - """ + See + https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 + In our case, the input string is expected to be just one line + and no HTML entities de-escaping is needed. + So we just tokenize on punctuation and symbols, + except when a punctuation is preceded and followed by a digit + (e.g. a comma/dot as a thousand/decimal separator). + + Note that a number (e.g., a year) followed by a dot at the end of sentence + is NOT tokenized, + i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g` + does not match this case (unless we add a space after each sentence). + However, this error is already in the original mteval-v14.pl + and we want to be consistent with it. + The error is not present in the non-international version, + which uses, + `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python). + + :param string: the input string + :return: a list of tokens + """ string = UnicodeRegex.nondigit_punct_re.sub(r'\1 \2 ', string) string = UnicodeRegex.punct_nondigit_re.sub(r' \1 \2', string) string = UnicodeRegex.symbol_re.sub(r' \1 ', string) @@ -127,94 +129,94 @@ def tokenize_v14_international(string): def tokenize_zh(sentence): """MIT License - Copyright (c) 2017 - Shujian Huang - - Permission is hereby granted, free of charge, to any person obtaining - a copy of this software and associated documentation files - (the "Software"), to deal in the Software without restriction, including - without limitation the rights to use, copy, modify, merge, publish, - distribute, sublicense, and/or sell copies of the Software, and to - permit persons to whom the Software is furnished to do so, subject to the - following conditions: - - The above copyright notice and this permission notice shall be included - in all copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, - DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR - OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE - USE OR OTHER DEALINGS IN THE SOFTWARE. - - The tokenization of Chinese text in this script contains two steps: - separate each Chinese characters (by utf-8 encoding); - tokenize the non Chinese part (following the mteval script). 
- Author: Shujian Huang huangsj@nju.edu.cn - - :param sentence: input sentence - :return: tokenized sentence - """ + Copyright (c) 2017 - Shujian Huang + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files + (the "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to the + following conditions: + + The above copyright notice and this permission notice shall be included + in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, + DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE + USE OR OTHER DEALINGS IN THE SOFTWARE. + + The tokenization of Chinese text in this script contains two steps: + separate each Chinese characters (by utf-8 encoding); + tokenize the non Chinese part (following the mteval script). + Author: Shujian Huang huangsj@nju.edu.cn + + :param sentence: input sentence + :return: tokenized sentence + """ def is_chinese_char(uchar): """ - :param uchar: input char in unicode - :return: whether the input char is a Chinese character. - """ - if "\u3400" <= uchar <= "\u4db5": + :param uchar: input char in unicode + :return: whether the input char is a Chinese character. 
+ """ + if '\u3400' <= uchar <= '\u4db5': return True - elif "\u4e00" <= uchar <= "\u9fa5": + elif '\u4e00' <= uchar <= '\u9fa5': return True - elif "\u9fa6" <= uchar <= "\u9fbb": + elif '\u9fa6' <= uchar <= '\u9fbb': return True - elif "\uf900" <= uchar <= "\ufa2d": + elif '\uf900' <= uchar <= '\ufa2d': return True - elif "\ufa30" <= uchar <= "\ufa6a": + elif '\ufa30' <= uchar <= '\ufa6a': return True - elif "\ufa70" <= uchar <= "\ufad9": + elif '\ufa70' <= uchar <= '\ufad9': return True - elif "\u20000" <= uchar <= "\u2a6d6": + elif '\u20000' <= uchar <= '\u2a6d6': return True - elif "\u2f800" <= uchar <= "\u2fa1d": + elif '\u2f800' <= uchar <= '\u2fa1d': return True - elif "\uff00" <= uchar <= "\uffef": + elif '\uff00' <= uchar <= '\uffef': return True - elif "\u2e80" <= uchar <= "\u2eff": + elif '\u2e80' <= uchar <= '\u2eff': return True - elif "\u3000" <= uchar <= "\u303f": + elif '\u3000' <= uchar <= '\u303f': return True - elif "\u31c0" <= uchar <= "\u31ef": + elif '\u31c0' <= uchar <= '\u31ef': return True - elif "\u2f00" <= uchar <= "\u2fdf": + elif '\u2f00' <= uchar <= '\u2fdf': return True - elif "\u2ff0" <= uchar <= "\u2fff": + elif '\u2ff0' <= uchar <= '\u2fff': return True - elif "\u3100" <= uchar <= "\u312f": + elif '\u3100' <= uchar <= '\u312f': return True - elif "\u31a0" <= uchar <= "\u31bf": + elif '\u31a0' <= uchar <= '\u31bf': return True - elif "\ufe10" <= uchar <= "\ufe1f": + elif '\ufe10' <= uchar <= '\ufe1f': return True - elif "\ufe30" <= uchar <= "\ufe4f": + elif '\ufe30' <= uchar <= '\ufe4f': return True - elif "\u2600" <= uchar <= "\u26ff": + elif '\u2600' <= uchar <= '\u26ff': return True - elif "\u2700" <= uchar <= "\u27bf": + elif '\u2700' <= uchar <= '\u27bf': return True - elif "\u3200" <= uchar <= "\u32ff": + elif '\u3200' <= uchar <= '\u32ff': return True - elif "\u3300" <= uchar <= "\u33ff": + elif '\u3300' <= uchar <= '\u33ff': return True return False sentence = sentence.strip() - sentence_in_chars = "" + sentence_in_chars = '' for char in sentence: if is_chinese_char(char): - sentence_in_chars += " " + sentence_in_chars += ' ' sentence_in_chars += char - sentence_in_chars += " " + sentence_in_chars += ' ' else: sentence_in_chars += char sentence = sentence_in_chars @@ -245,10 +247,10 @@ def is_chinese_char(uchar): TOKENIZERS = { - '13a': tokenize_13a, - 'intl': tokenize_v14_international, - 'zh': tokenize_zh, - 'none': lambda x: x, + '13a': tokenize_13a, + 'intl': tokenize_v14_international, + 'zh': tokenize_zh, + 'none': lambda x: x, } DEFAULT_TOKENIZER = '13a' @@ -256,16 +258,16 @@ def is_chinese_char(uchar): def extract_ngrams(line, min_order=1, max_order=NGRAM_ORDER) -> Counter: """Extracts all the ngrams (1 <= n <= NGRAM_ORDER) from a sequence of tokens. 
- :param line: a segment containing a sequence of words - :param max_order: collect n-grams from 1<=n<=max - :return: a dictionary containing ngrams and counts - """ + :param line: a segment containing a sequence of words + :param max_order: collect n-grams from 1<=n<=max + :return: a dictionary containing ngrams and counts + """ ngrams = Counter() tokens = line.split() for n in range(min_order, max_order + 1): for i in range(0, len(tokens) - n + 1): - ngram = ' '.join(tokens[i:i + n]) + ngram = ' '.join(tokens[i : i + n]) ngrams[ngram] += 1 return ngrams @@ -293,41 +295,44 @@ def ref_stats(output, refs): return ngrams, closest_diff, closest_len -BLEU = namedtuple('BLE', - 'score, counts, totals, precisions, bp, sys_len, ref_len') +BLEU = namedtuple( + 'BLE', 'score, counts, totals, precisions, bp, sys_len, ref_len' +) -def compute_bleu(correct: List[int], - total: List[int], - sys_len: int, - ref_len: int, - smooth_method='none', - smooth_value=SMOOTH_VALUE_DEFAULT, - use_effective_order=False) -> BLEU: +def compute_bleu( + correct: List[int], + total: List[int], + sys_len: int, + ref_len: int, + smooth_method='none', + smooth_value=SMOOTH_VALUE_DEFAULT, + use_effective_order=False, +) -> BLEU: """Computes BLEU score from its sufficient statistics. Adds smoothing. - Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques - for Sentence-Level BLEU", Boxing Chen and Colin Cherry, - WMT 2014: http://aclweb.org/anthology/W14-3346) - - - exp: NIST smoothing method (Method 3) - - floor: Method 1 - - add-k: Method 2 (generalizing Lin and Och, 2004) - - none: do nothing. - - :param correct: List of counts of correct ngrams, 1 <= n <= NGRAM_ORDER - :param total: List of counts of total ngrams, 1 <= n <= NGRAM_ORDER - :param sys_len: The cumulative system length - :param ref_len: The cumulative reference length - :param smooth: The smoothing method to use - :param smooth_value: The smoothing value added, if smooth is 'floor' - :param use_effective_order: Use effective order. - :return: A BLEU object with the score (100-based) and other statistics. - """ + Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques + for Sentence-Level BLEU", Boxing Chen and Colin Cherry, + WMT 2014: http://aclweb.org/anthology/W14-3346) + + - exp: NIST smoothing method (Method 3) + - floor: Method 1 + - add-k: Method 2 (generalizing Lin and Och, 2004) + - none: do nothing. + + :param correct: List of counts of correct ngrams, 1 <= n <= NGRAM_ORDER + :param total: List of counts of total ngrams, 1 <= n <= NGRAM_ORDER + :param sys_len: The cumulative system length + :param ref_len: The cumulative reference length + :param smooth: The smoothing method to use + :param smooth_value: The smoothing value added, if smooth is 'floor' + :param use_effective_order: Use effective order. + :return: A BLEU object with the score (100-based) and other statistics. + """ precisions = [0 for x in range(NGRAM_ORDER)] - smooth_mteval = 1. + smooth_mteval = 1.0 effective_order = NGRAM_ORDER for n in range(NGRAM_ORDER): if smooth_method == 'add-k' and n > 1: @@ -342,11 +347,11 @@ def compute_bleu(correct: List[int], if correct[n] == 0: if smooth_method == 'exp': smooth_mteval *= 2 - precisions[n] = 100. / (smooth_mteval * total[n]) + precisions[n] = 100.0 / (smooth_mteval * total[n]) elif smooth_method == 'floor': - precisions[n] = 100. * smooth_value / total[n] + precisions[n] = 100.0 * smooth_value / total[n] else: - precisions[n] = 100. 
* correct[n] / total[n] + precisions[n] = 100.0 * correct[n] / total[n] # If the system guesses no i-grams, 1 <= i <= NGRAM_ORDER, the BLEU # score is 0 (technically undefined). This is a problem for sentence-level @@ -360,20 +365,24 @@ def compute_bleu(correct: List[int], brevity_penalty = math.exp(1 - ref_len / sys_len) if sys_len > 0 else 0.0 bleu = brevity_penalty * math.exp( - sum(map(my_log, precisions[:effective_order])) / effective_order) + sum(map(my_log, precisions[:effective_order])) / effective_order + ) return BLEU._make( - [bleu, correct, total, precisions, brevity_penalty, sys_len, ref_len]) - - -def corpus_bleu(sys_stream: Sequence[str], - ref_streams: Sequence[str], - smooth_method: str = 'exp', - smooth_value: float = 0.0, - force: bool = False, - lowercase: bool = False, - tokenize: str = '13a', - use_effective_order: bool = False) -> BLEU: + [bleu, correct, total, precisions, brevity_penalty, sys_len, ref_len] + ) + + +def corpus_bleu( + sys_stream: Sequence[str], + ref_streams: Sequence[str], + smooth_method: str = 'exp', + smooth_value: float = 0.0, + force: bool = False, + lowercase: bool = False, + tokenize: str = '13a', + use_effective_order: bool = False, +) -> BLEU: """Produces BLEU scores along with its sufficient statistics from a source against one or more references. :param sys_stream: The system stream (a sequence of segments). @@ -414,13 +423,16 @@ def corpus_bleu(sys_stream: Sequence[str], tokenized_count += 1 if tokenized_count == 100: + logging.warning("That's 100 lines that end in a tokenized period ('.')") + logging.warning( + 'It looks like you forgot to detokenize your test ' + 'data, which may hurt your score.' + ) logging.warning( - 'That\'s 100 lines that end in a tokenized period (\'.\')') - logging.warning('It looks like you forgot to detokenize your test ' - 'data, which may hurt your score.') - logging.warning('If you insist your data is detokenized, ' - 'or don\'t care, you can suppress this message with ' - '\'--force\'.') + 'If you insist your data is detokenized, ' + "or don't care, you can suppress this message with " + "'--force'." 
+ ) output, *refs = [TOKENIZERS[tokenize](x.rstrip()) for x in lines] @@ -453,10 +465,11 @@ def corpus_bleu(sys_stream: Sequence[str], total = total.cpu().numpy().tolist() return compute_bleu( - correct, - total, - sys_len, - ref_len, - smooth_method=smooth_method, - smooth_value=smooth_value, - use_effective_order=use_effective_order) + correct, + total, + sys_len, + ref_len, + smooth_method=smooth_method, + smooth_value=smooth_value, + use_effective_order=use_effective_order, + ) diff --git a/algoperf/workloads/wmt/input_pipeline.py b/algoperf/workloads/wmt/input_pipeline.py index d743b43b0..3d184cd78 100644 --- a/algoperf/workloads/wmt/input_pipeline.py +++ b/algoperf/workloads/wmt/input_pipeline.py @@ -1,4 +1,5 @@ """Input pipeline for a WMT dataset.""" + import functools import os from typing import Dict, List, Optional, Union @@ -16,10 +17,10 @@ Features = Dict[str, tf.Tensor] TFDS_SPLIT_NAME = { - 'train': 'train', - 'eval_train': 'train', - 'validation': 'validation', - 'test': 'test', + 'train': 'train', + 'eval_train': 'train', + 'validation': 'validation', + 'test': 'test', } @@ -31,9 +32,11 @@ def normalize_feature_names(ds_info, features: Features) -> Features: return features -def pack_dataset(dataset: tf.data.Dataset, - key2length: Union[int, Dict[str, int]], - keys: Optional[List[str]] = None) -> tf.data.Dataset: +def pack_dataset( + dataset: tf.data.Dataset, + key2length: Union[int, Dict[str, int]], + keys: Optional[List[str]] = None, +) -> tf.data.Dataset: """Creates a 'packed' version of a dataset on-the-fly. Adapted from the mesh-tf implementation. This is meant to replace the irritation of having to create a separate @@ -75,7 +78,8 @@ def pack_dataset(dataset: tf.data.Dataset, for k in keys: if k not in shapes: raise ValueError( - f'Key {k} not found in dataset. Available keys are {shapes.keys()}') + f'Key {k} not found in dataset. Available keys are {shapes.keys()}' + ) if not shapes[k].is_compatible_with(tf.TensorShape([None])): raise ValueError('Tensors to be packed must be one-dimensional.') # make sure that the length dictionary contains all keys as well as the @@ -88,13 +92,15 @@ def pack_dataset(dataset: tf.data.Dataset, # trim to length dataset = dataset.map( - lambda x: {k: x[k][:key2length[k]] for k in keys}, - num_parallel_calls=AUTOTUNE) + lambda x: {k: x[k][: key2length[k]] for k in keys}, + num_parallel_calls=AUTOTUNE, + ) # Setting batch_size=length ensures that the concatenated sequences (if they # have length >=1) are sufficient to fill at least one packed example. batch_size = max(key2length.values()) dataset = dataset.padded_batch( - batch_size, padded_shapes={k: [-1] for k in keys}) + batch_size, padded_shapes={k: [-1] for k in keys} + ) dataset = _pack_with_tf_ops(dataset, keys, key2length) # Set the Tensor shapes correctly since they get lost in the process. @@ -104,9 +110,9 @@ def my_fn(x): return dataset.map(my_fn, num_parallel_calls=AUTOTUNE) -def _pack_with_tf_ops(dataset: tf.data.Dataset, - keys: List[str], - key2length: Dict[str, int]) -> tf.data.Dataset: +def _pack_with_tf_ops( + dataset: tf.data.Dataset, keys: List[str], key2length: Dict[str, int] +) -> tf.data.Dataset: """Helper-function for packing a dataset which has already been batched. Helper for pack_dataset() Uses tf.while_loop. 
Args: @@ -127,8 +133,9 @@ def write_packed_example(partial, outputs): new_outputs = {} for k in keys_etc: new_outputs[k] = outputs[k].write( - outputs[k].size(), - tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]])) + outputs[k].size(), + tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]), + ) return new_partial, new_outputs def map_fn(x): @@ -146,9 +153,11 @@ def map_fn(x): outputs = {} for k in keys: outputs[k] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + ) outputs[k + '_position'] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + ) def body_fn(i, partial, outputs): """Body function for while_loop. @@ -163,13 +172,15 @@ def body_fn(i, partial, outputs): one_example = {} for k in keys: val = tf.cast(x[k][i], tf.int32) - val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))] + val = val[: tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))] one_example[k] = val for k in keys: can_append = tf.logical_and( - can_append, - tf.less_equal( - tf.size(partial[k]) + tf.size(one_example[k]), key2length[k])) + can_append, + tf.less_equal( + tf.size(partial[k]) + tf.size(one_example[k]), key2length[k] + ), + ) def false_fn(): return write_packed_example(partial, outputs) @@ -180,53 +191,55 @@ def true_fn(): partial, outputs = tf.cond(can_append, true_fn, false_fn) new_partial = {} for k in keys: - new_seq = one_example[k][:key2length[k]] + new_seq = one_example[k][: key2length[k]] new_seq_len = tf.size(new_seq) new_partial[k] = tf.concat([partial[k], new_seq], 0) new_partial[k + '_position'] = tf.concat( - [partial[k + '_position'], tf.range(new_seq_len)], 0) + [partial[k + '_position'], tf.range(new_seq_len)], 0 + ) partial = new_partial return i + 1, partial, outputs # For loop over all examples in the batch. 
i, partial, outputs = tf.while_loop( - cond=lambda *_: True, - body=body_fn, - loop_vars=(i, partial, outputs), - shape_invariants=( - tf.TensorShape([]), - {k: tf.TensorShape([None]) for k in keys_etc}, - {k: tf.TensorShape(None) for k in keys_etc}, - ), - maximum_iterations=dynamic_batch_size) + cond=lambda *_: True, + body=body_fn, + loop_vars=(i, partial, outputs), + shape_invariants=( + tf.TensorShape([]), + {k: tf.TensorShape([None]) for k in keys_etc}, + {k: tf.TensorShape(None) for k in keys_etc}, + ), + maximum_iterations=dynamic_batch_size, + ) _, outputs = write_packed_example(partial, outputs) packed = {k: outputs[k].stack() for k in keys_etc} for k in keys: - packed[k + '_segmentation'] = ( - tf.cumsum( - tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1) * - tf.cast(tf.not_equal(packed[k], 0), tf.int32)) + packed[k + '_segmentation'] = tf.cumsum( + tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1 + ) * tf.cast(tf.not_equal(packed[k], 0), tf.int32) return packed dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE) return dataset.unbatch() -def preprocess_wmt_data(dataset: tf.data.Dataset, - data_rng, - train: bool, - shuffle: bool, - shuffle_buffer_size: int = 1024, - max_length: int = 256, - global_batch_size: int = 128): +def preprocess_wmt_data( + dataset: tf.data.Dataset, + data_rng, + train: bool, + shuffle: bool, + shuffle_buffer_size: int = 1024, + max_length: int = 256, + global_batch_size: int = 128, +): """Shuffle and batch/pack the given dataset.""" def length_filter(max_len): - def filter_fn(x): source, target = x['inputs'], x['targets'] - l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0]) - return tf.less(l, max_len + 1) + length = tf.maximum(tf.shape(source)[0], tf.shape(target)[0]) + return tf.less(length, max_len + 1) return filter_fn @@ -242,24 +255,27 @@ def filter_fn(x): dataset = dataset.batch(global_batch_size, drop_remainder=train) else: # simple (static-shape) padded batching dataset = dataset.padded_batch( - global_batch_size, - padded_shapes={'inputs': max_length, 'targets': max_length}, - padding_values={'inputs': 0, 'targets': 0}, - drop_remainder=False) + global_batch_size, + padded_shapes={'inputs': max_length, 'targets': max_length}, + padding_values={'inputs': 0, 'targets': 0}, + drop_remainder=False, + ) dataset = dataset.prefetch(AUTOTUNE) return dataset -def get_wmt_dataset(data_rng, - split: str, - data_dir: str, - is_training: bool, - vocab_size: int, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False, - vocab_path: Optional[str] = None): +def get_wmt_dataset( + data_rng, + split: str, + data_dir: str, + is_training: bool, + vocab_size: int, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, + vocab_path: Optional[str] = None, +): """Load and return dataset of batched examples for use during training.""" if vocab_path is None: vocab_path = os.path.join(data_dir, 'wmt_sentencepiece_model') @@ -271,7 +287,8 @@ def get_wmt_dataset(data_rng, dataset_builder = tfds.builder(ds_name, data_dir=data_dir) ds = dataset_builder.as_dataset( - split=TFDS_SPLIT_NAME[split], shuffle_files=False) + split=TFDS_SPLIT_NAME[split], shuffle_files=False + ) # Avoid creating too many threads when using PyTorch DDP. 
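
To make the packing above concrete, here is a hand-constructed (not machine-generated) view of two short token sequences packed into one row of length 8, following the `_position` and `_segmentation` conventions computed in `_pack_with_tf_ops`: positions restart at 0 for each original example, and the segmentation id is the running count of position-0 markers, zeroed out on padding.

    # Illustrative packed row (hand-written to match the conventions above).
    packed_row = {
        'inputs':              [7, 8, 9, 4, 5, 0, 0, 0],
        'inputs_position':     [0, 1, 2, 0, 1, 0, 0, 0],
        'inputs_segmentation': [1, 1, 1, 2, 2, 0, 0, 0],
    }
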
if RANK != 0: @@ -280,8 +297,9 @@ def get_wmt_dataset(data_rng, ds = ds.with_options(options) ds = ds.map( - functools.partial(normalize_feature_names, dataset_builder.info), - num_parallel_calls=AUTOTUNE) + functools.partial(normalize_feature_names, dataset_builder.info), + num_parallel_calls=AUTOTUNE, + ) # Load tf-text SentencePiece tokenizer. sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path) @@ -289,12 +307,13 @@ def get_wmt_dataset(data_rng, shuffle = split in ['train', 'eval_train'] ds = preprocess_wmt_data( - ds, - data_rng, - train=is_training, - shuffle=shuffle, - global_batch_size=global_batch_size, - max_length=256) + ds, + data_rng, + train=is_training, + shuffle=shuffle, + global_batch_size=global_batch_size, + max_length=256, + ) if num_batches: ds = ds.take(num_batches) @@ -303,9 +322,10 @@ def get_wmt_dataset(data_rng, ds = ds.repeat() ds = map( - functools.partial( - data_utils.shard_and_maybe_pad_np, - global_batch_size=global_batch_size), - ds) + functools.partial( + data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size + ), + ds, + ) return ds, sp_tokenizer diff --git a/algoperf/workloads/wmt/tokenizer.py b/algoperf/workloads/wmt/tokenizer.py index 1f001e619..273e11dfa 100644 --- a/algoperf/workloads/wmt/tokenizer.py +++ b/algoperf/workloads/wmt/tokenizer.py @@ -9,19 +9,19 @@ import time from typing import Any, Dict, Iterable, Tuple -from absl import logging import jax -from sentencepiece import SentencePieceTrainer import tensorflow as tf import tensorflow_text as tftxt +from absl import logging +from sentencepiece import SentencePieceTrainer Features = Dict[str, tf.Tensor] def _dump_chars_to_textfile( - dataset: tf.data.Dataset, - maxchars: int = int(1e7), - data_keys=('inputs', 'targets') + dataset: tf.data.Dataset, + maxchars: int = int(1e7), + data_keys=('inputs', 'targets'), ) -> Tuple[str, int]: """Write part of a TFDS sentence dataset to lines in a text file. @@ -36,7 +36,8 @@ def _dump_chars_to_textfile( char_count = 0 ds_iter = dataset.as_numpy_iterator() with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/ds_chars') as outfp: + delete=False, prefix='/tmp/ds_chars' + ) as outfp: while char_count < maxchars: example = next(ds_iter) for k in data_keys: @@ -46,14 +47,16 @@ def _dump_chars_to_textfile( return outfp.name, char_count -def _train_sentencepiece(dataset: tf.data.Dataset, - *, - vocab_size: int, - maxchars: int = int(1e7), - model_path: str, - model_type: str = 'unigram', - character_coverage: float = 1.0, - data_keys=('inputs', 'targets')): +def _train_sentencepiece( + dataset: tf.data.Dataset, + *, + vocab_size: int, + maxchars: int = int(1e7), + model_path: str, + model_type: str = 'unigram', + character_coverage: float = 1.0, + data_keys=('inputs', 'targets'), +): """Train SentencePiece tokenizer from subset of tf dataset. 
Args: @@ -75,17 +78,21 @@ def _train_sentencepiece(dataset: tf.data.Dataset, else: abs_model_path = os.path.abspath(os.path.expanduser(model_path)) fname, _ = _dump_chars_to_textfile( - dataset, maxchars=maxchars, data_keys=data_keys) + dataset, maxchars=maxchars, data_keys=data_keys + ) with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/sp_tmp') as model_fp: + delete=False, prefix='/tmp/sp_tmp' + ) as model_fp: pass # we just want a prefix'd tmp-filename - argstr = ' '.join([ + argstr = ' '.join( + [ f'--input={fname}', f'--vocab_size={vocab_size}', f'--character_coverage={character_coverage}', f'--model_prefix={model_fp.name}', f'--model_type={model_type}', - ]) + ] + ) SentencePieceTrainer.Train(argstr) if jax.process_index() == 0: # Use an intermediate filename that is renamed to the target name to address @@ -100,32 +107,38 @@ def _train_sentencepiece(dataset: tf.data.Dataset, time.sleep(1) -def _load_sentencepiece_tokenizer(model_path: str, - add_bos: bool = False, - add_eos: bool = True, - reverse: bool = False): +def _load_sentencepiece_tokenizer( + model_path: str, + add_bos: bool = False, + add_eos: bool = True, + reverse: bool = False, +): """Load a tf-text SentencePiece tokenizer from given model filepath.""" with tf.io.gfile.GFile(model_path, 'rb') as model_fp: sp_model = model_fp.read() sp_tokenizer = tftxt.SentencepieceTokenizer( - model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse) + model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse + ) return sp_tokenizer -def train_tokenizer(dataset: tf.data.Dataset, - *, - vocab_path: str, - vocab_size: int, - max_corpus_chars: int, - data_keys: Tuple[str, str] = ('inputs', 'targets')): +def train_tokenizer( + dataset: tf.data.Dataset, + *, + vocab_path: str, + vocab_size: int, + max_corpus_chars: int, + data_keys: Tuple[str, str] = ('inputs', 'targets'), +): """Trains a tokenizer from `dataset`.""" logging.info('Building SentencePiece vocab from data.') _train_sentencepiece( - dataset, - vocab_size=vocab_size, - maxchars=max_corpus_chars, - model_path=vocab_path, - data_keys=data_keys) + dataset, + vocab_size=vocab_size, + maxchars=max_corpus_chars, + model_path=vocab_path, + data_keys=data_keys, + ) def load_tokenizer(vocab_path: str): @@ -135,7 +148,6 @@ def load_tokenizer(vocab_path: str): @dataclasses.dataclass class TokenizeOp: - sp_tokenizer: Any data_keys: Iterable[str] = ('inputs', 'targets') diff --git a/algoperf/workloads/wmt/wmt_jax/decode.py b/algoperf/workloads/wmt/wmt_jax/decode.py index dfead5918..196d9175e 100644 --- a/algoperf/workloads/wmt/wmt_jax/decode.py +++ b/algoperf/workloads/wmt/wmt_jax/decode.py @@ -7,9 +7,9 @@ import flax import jax -from jax import lax import jax.numpy as jnp import numpy as np +from jax import lax # Constants # We assume the default End-of-Sentence token id is 2 (SentencePiece). @@ -78,8 +78,9 @@ def gather_beams(nested, beam_indices, batch_size, new_beam_size): [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...] """ batch_indices = jnp.reshape( - jnp.arange(batch_size * new_beam_size) // new_beam_size, - (batch_size, new_beam_size)) + jnp.arange(batch_size * new_beam_size) // new_beam_size, + (batch_size, new_beam_size), + ) def gather_fn(x): if x.ndim < 2: # ignore scalars (e.g. cache index) @@ -114,6 +115,7 @@ def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size): @flax.struct.dataclass class BeamState: """Holds beam search state data.""" + # The position of the decoding loop in the length dimension. 
cur_index: jax.Array # scalar int32: current decoded length index # The active sequence log probabilities and finished sequence scores. @@ -133,7 +135,8 @@ def beam_init(batch_size, beam_size, max_decode_len, cache): """Initializes the beam search state data structure.""" cur_index0 = jnp.array(0) live_logprobs0 = jnp.tile( - jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1]) + jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1] + ) finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF live_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) @@ -141,25 +144,28 @@ def beam_init(batch_size, beam_size, max_decode_len, cache): # add beam dimension to attention cache pytree elements beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache) return BeamState( - cur_index=cur_index0, - live_logprobs=live_logprobs0, - finished_scores=finished_scores0, - live_seqs=live_seqs0, - finished_seqs=finished_seqs0, - finished_flags=finished_flags0, - cache=beam_cache0) + cur_index=cur_index0, + live_logprobs=live_logprobs0, + finished_scores=finished_scores0, + live_seqs=live_seqs0, + finished_seqs=finished_seqs0, + finished_flags=finished_flags0, + cache=beam_cache0, + ) # Beam search routine: -def beam_search(inputs, - cache, - tokens_to_logits, - beam_size=4, - alpha=0.6, - eos_id=EOS_ID, - max_decode_len=None): +def beam_search( + inputs, + cache, + tokens_to_logits, + beam_size=4, + alpha=0.6, + eos_id=EOS_ID, + max_decode_len=None, +): """Beam search for transformer machine translation. Args: @@ -185,10 +191,9 @@ def beam_search(inputs, end_marker = jnp.array(eos_id) # initialize beam search state - beam_search_init_state = beam_init(batch_size, - beam_size, - max_decode_len, - cache) + beam_search_init_state = beam_init( + batch_size, beam_size, max_decode_len, cache + ) def beam_search_loop_cond_fn(state): """Beam search loop termination condition.""" @@ -201,11 +206,12 @@ def beam_search_loop_cond_fn(state): best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty # Get the worst scores from finished sequences. worst_finished_scores = jnp.min( - state.finished_scores, axis=1, keepdims=True) + state.finished_scores, axis=1, keepdims=True + ) # Mask out scores from slots without any actual finished sequences. - worst_finished_scores = jnp.where(state.finished_flags, - worst_finished_scores, - NEG_INF) + worst_finished_scores = jnp.where( + state.finished_flags, worst_finished_scores, NEG_INF + ) # If no best possible live score is better than current worst finished # scores, the search cannot improve the finished set further. search_terminated = jnp.all(worst_finished_scores > best_live_scores) @@ -221,8 +227,10 @@ def beam_search_loop_body_fn(state): # dimension for feeding into the model. # --> [batch * beam, 1] flat_ids = flatten_beam_dim( - lax.dynamic_slice(state.live_seqs, (0, 0, state.cur_index), - (batch_size, beam_size, 1))) + lax.dynamic_slice( + state.live_seqs, (0, 0, state.cur_index), (batch_size, beam_size, 1) + ) + ) # Flatten beam dimension into batch to be compatible with model. 
# {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} flat_cache = jax.tree.map(flatten_beam_dim, state.cache) @@ -237,14 +245,16 @@ def beam_search_loop_body_fn(state): # Unflatten beam dimension in attention cache arrays # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} new_cache = jax.tree.map( - lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache) + lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache + ) # Gather log probabilities from logits candidate_log_probs = jax.nn.log_softmax(logits) # Add new logprobs to existing prefix logprobs. # --> [batch, beam, vocab] - log_probs = ( - candidate_log_probs + jnp.expand_dims(state.live_logprobs, axis=2)) + log_probs = candidate_log_probs + jnp.expand_dims( + state.live_logprobs, axis=2 + ) # We'll need the vocab size, gather it from the log probability dimension. vocab_size = log_probs.shape[2] @@ -264,10 +274,9 @@ def beam_search_loop_body_fn(state): topk_beam_indices = topk_indices // vocab_size # Gather 2*k top beams. # --> [batch, 2*beams, length] - topk_seq = gather_beams(state.live_seqs, - topk_beam_indices, - batch_size, - beams_to_keep) + topk_seq = gather_beams( + state.live_seqs, topk_beam_indices, batch_size, beams_to_keep + ) # Append the most probable 2*K token IDs to the top 2*K sequences # Recover token id by modulo division and expand Id array for broadcasting. @@ -275,13 +284,14 @@ def beam_search_loop_body_fn(state): topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2) # Update sequences for the 2*K top-k new sequences. # --> [batch, 2*beams, length] - topk_seq = lax.dynamic_update_slice(topk_seq, - topk_ids, (0, 0, state.cur_index + 1)) + topk_seq = lax.dynamic_update_slice( + topk_seq, topk_ids, (0, 0, state.cur_index + 1) + ) # Update LIVE (in-progress) sequences: # Did any of these sequences reach an end marker? # --> [batch, 2*beams] - newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker) + newly_finished = topk_seq[:, :, state.cur_index + 1] == end_marker # To prevent these newly finished sequences from being added to the LIVE # set of active beam search sequences, set their log probs to a very large # negative value. @@ -292,22 +302,20 @@ def beam_search_loop_body_fn(state): new_topk_indices = jnp.flip(new_topk_indices, axis=1) # Gather the top k beams (from top 2*k beams). # --> [batch, beams, length], [batch, beams] - top_alive_seq, top_alive_log_probs = gather_beams([topk_seq, new_log_probs], - new_topk_indices, - batch_size, beam_size) + top_alive_seq, top_alive_log_probs = gather_beams( + [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size + ) # Determine the top k beam indices from the original set of all beams. # --> [batch, beams] - top_alive_indices = gather_beams(topk_beam_indices, - new_topk_indices, - batch_size, - beam_size) + top_alive_indices = gather_beams( + topk_beam_indices, new_topk_indices, batch_size, beam_size + ) # With these, gather the top k beam-associated caches. # --> {[batch, beams, ...], ...} - top_alive_cache = gather_beams(new_cache, - top_alive_indices, - batch_size, - beam_size) + top_alive_cache = gather_beams( + new_cache, top_alive_indices, batch_size, beam_size + ) # Update FINISHED (reached end of sentence) sequences: # Calculate new seq scores from log probabilities. @@ -320,42 +328,54 @@ def beam_search_loop_body_fn(state): # new finished sequence scores to existing finished scores and select the # best from the new set of beams. 
finished_seqs = jnp.concatenate( # --> [batch, 3*beams, length] - [state.finished_seqs, topk_seq], - axis=1) + [state.finished_seqs, topk_seq], axis=1 + ) finished_scores = jnp.concatenate( # --> [batch, 3*beams] - [state.finished_scores, new_scores], axis=1) + [state.finished_scores, new_scores], axis=1 + ) finished_flags = jnp.concatenate( # --> [batch, 3*beams] - [state.finished_flags, newly_finished], axis=1) + [state.finished_flags, newly_finished], axis=1 + ) # --> [batch, beams, length], [batch, beams], [batch, beams] top_finished_seq, top_finished_scores, top_finished_flags = ( - gather_topk_beams([finished_seqs, finished_scores, finished_flags], - finished_scores, batch_size, beam_size)) + gather_topk_beams( + [finished_seqs, finished_scores, finished_flags], + finished_scores, + batch_size, + beam_size, + ) + ) return BeamState( - cur_index=state.cur_index + 1, - live_logprobs=top_alive_log_probs, - finished_scores=top_finished_scores, - live_seqs=top_alive_seq, - finished_seqs=top_finished_seq, - finished_flags=top_finished_flags, - cache=top_alive_cache) + cur_index=state.cur_index + 1, + live_logprobs=top_alive_log_probs, + finished_scores=top_finished_scores, + live_seqs=top_alive_seq, + finished_seqs=top_finished_seq, + finished_flags=top_finished_flags, + cache=top_alive_cache, + ) # Run while loop and get final beam search state. - final_state = lax.while_loop(beam_search_loop_cond_fn, - beam_search_loop_body_fn, - beam_search_init_state) + final_state = lax.while_loop( + beam_search_loop_cond_fn, beam_search_loop_body_fn, beam_search_init_state + ) # Account for the edge-case where there are no finished sequences for a # particular batch item. If so, return live sequences for that batch item. # --> [batch] none_finished = jnp.any(final_state.finished_flags, axis=1) # --> [batch, beams, length] - finished_seqs = jnp.where(none_finished[:, None, None], - final_state.finished_seqs, - final_state.live_seqs) + finished_seqs = jnp.where( + none_finished[:, None, None], + final_state.finished_seqs, + final_state.live_seqs, + ) # --> [batch, beams] - finished_scores = jnp.where(none_finished[:, None], - final_state.finished_scores, - final_state.live_logprobs) + finished_scores = jnp.where( + none_finished[:, None], + final_state.finished_scores, + final_state.live_logprobs, + ) return finished_seqs, finished_scores diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py index 97fee032f..81f2ece4c 100644 --- a/algoperf/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -5,16 +5,21 @@ from typing import Any, Callable, Optional +import jax.numpy as jnp +import numpy as np from flax import linen as nn from flax import struct from jax import lax -import jax.numpy as jnp -import numpy as np + +from algoperf.jax_utils import Dropout + +DROPOUT_RATE = 0.1 @struct.dataclass class TransformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + share_embeddings: bool = True dtype: Any = jnp.float32 vocab_size: int = 32000 @@ -26,10 +31,6 @@ class TransformerConfig: max_len: int = 256 activation: Callable = nn.relu glu: bool = False - #If None, defaults to 0.1. - dropout_rate: Optional[float] = 0.1 - #If None, defaults to 0.1. 
- attention_dropout_rate: Optional[float] = 0.1 attention_temp: float = 1.0 deterministic: bool = False decode: bool = False @@ -44,7 +45,8 @@ def shift_right(x, axis=1): pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) padded = jnp.pad( - x, pad_widths, mode='constant', constant_values=x.dtype.type(0)) + x, pad_widths, mode='constant', constant_values=x.dtype.type(0) + ) return padded[:, :-1] @@ -68,8 +70,8 @@ def init(key, shape, dtype=np.float32): position = np.arange(0, max_len)[:, np.newaxis] scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1) div_term = min_scale * np.exp(np.arange(0, d_feature // 2) * scale_factor) - pe[:, :d_feature // 2] = np.sin(position * div_term) - pe[:, d_feature // 2:2 * (d_feature // 2)] = np.cos(position * div_term) + pe[:, : d_feature // 2] = np.sin(position * div_term) + pe[:, d_feature // 2 : 2 * (d_feature // 2)] = np.cos(position * div_term) pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] return jnp.array(pe) @@ -83,6 +85,7 @@ class AddPositionEmbs(nn.Module): config: TransformerConfig dataclass containing hyperparameters. decode: whether to run in single-position autoregressive mode. """ + config: TransformerConfig decode: bool = False @@ -103,27 +106,28 @@ def __call__(self, inputs, inputs_positions=None): """ cfg = self.config # inputs.shape is (batch_size, seq_len, emb_dim) - assert inputs.ndim == 3, ('Number of dimensions should be 3,' - f' but it is: {inputs.ndim}') + assert inputs.ndim == 3, ( + f'Number of dimensions should be 3, but it is: {inputs.ndim}' + ) length = inputs.shape[1] pos_emb_shape = (1, cfg.max_len, inputs.shape[-1]) if cfg.posemb_init is None: # Use a fixed (non-learned) sinusoidal position embedding. - pos_embedding = sinusoidal_init(max_len=cfg.max_len)(None, - pos_emb_shape, - None) + pos_embedding = sinusoidal_init(max_len=cfg.max_len)( + None, pos_emb_shape, None + ) else: - pos_embedding = self.param('pos_embedding', - cfg.posemb_init, - pos_emb_shape) + pos_embedding = self.param( + 'pos_embedding', cfg.posemb_init, pos_emb_shape + ) pe = pos_embedding[:, :length, :] # We use a cache position index for tracking decoding position. if self.decode: is_initialized = self.has_variable('cache', 'cache_index') - cache_index = self.variable('cache', - 'cache_index', - lambda: jnp.array(0, dtype=jnp.uint32)) + cache_index = self.variable( + 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.uint32) + ) if is_initialized: i = cache_index.value cache_index.value = i + 1 @@ -144,43 +148,43 @@ class MlpBlock(nn.Module): config: TransformerConfig dataclass containing hyperparameters. out_dim: optionally specify out dimension. 
""" + config: TransformerConfig out_dim: Optional[int] = None @nn.compact - def __call__(self, inputs): + def __call__(self, inputs, dropout_rate=DROPOUT_RATE): """Applies Transformer MlpBlock module.""" cfg = self.config - actual_out_dim = ( - inputs.shape[-1] if self.out_dim is None else self.out_dim) + + actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim x = nn.Dense( - cfg.mlp_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init)( - inputs) + cfg.mlp_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + )(inputs) x = cfg.activation(x) if cfg.glu: y = nn.Dense( - cfg.mlp_dim, - dtype=cfg.dtype, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init)( - inputs) - x = x * y - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) - output = nn.Dense( - actual_out_dim, + cfg.mlp_dim, dtype=cfg.dtype, kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init)( - x) - output = nn.Dropout(rate=dropout_rate)( - output, deterministic=cfg.deterministic) + bias_init=cfg.bias_init, + )(inputs) + x = x * y + x = Dropout(rate=dropout_rate)( + x, rate=dropout_rate, deterministic=cfg.deterministic + ) + output = nn.Dense( + actual_out_dim, + dtype=cfg.dtype, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + )(x) + output = Dropout(rate=dropout_rate)( + output, rate=dropout_rate, deterministic=cfg.deterministic + ) return output @@ -190,10 +194,11 @@ class Encoder1DBlock(nn.Module): Attributes: config: TransformerConfig dataclass containing hyperparameters. """ + config: TransformerConfig @nn.compact - def __call__(self, inputs, encoder_mask=None): + def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE): """Applies Encoder1DBlock module. Args: @@ -204,39 +209,34 @@ def __call__(self, inputs, encoder_mask=None): output after transformer encoder block. """ cfg = self.config + pre_ln = cfg.pre_ln # Attention block. assert inputs.ndim == 3 x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs - if cfg.attention_dropout_rate is None: - attention_dropout_rate = 0.1 - else: - attention_dropout_rate = cfg.attention_dropout_rate x = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)( - cfg.attention_temp * x, x, mask=encoder_mask) - - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, + )(cfg.attention_temp * x, x, mask=encoder_mask) + + x = Dropout(rate=dropout_rate)( + x, deterministic=cfg.deterministic, rate=dropout_rate + ) x = x + inputs if not pre_ln: x = nn.LayerNorm(dtype=cfg.dtype)(x) # MLP block. 
y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x - y = MlpBlock(config=cfg)(y) + y = MlpBlock(config=cfg)(y, dropout_rate=dropout_rate) return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y) @@ -247,14 +247,18 @@ class EncoderDecoder1DBlock(nn.Module): Attributes: config: TransformerConfig dataclass containing hyperparameters. """ + config: TransformerConfig @nn.compact - def __call__(self, - targets, - encoded, - decoder_mask=None, - encoder_decoder_mask=None): + def __call__( + self, + targets, + encoded, + decoder_mask=None, + encoder_decoder_mask=None, + dropout_rate=DROPOUT_RATE, + ): """Applies EncoderDecoder1DBlock module. Args: @@ -267,33 +271,29 @@ def __call__(self, output after transformer encoder-decoder block. """ cfg = self.config + pre_ln = cfg.pre_ln # Decoder block. assert targets.ndim == 3 x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets - if cfg.attention_dropout_rate is None: - attention_dropout_rate = 0.1 - else: - attention_dropout_rate = cfg.attention_dropout_rate x = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic, - decode=cfg.decode)( - cfg.attention_temp * x, x, mask=decoder_mask) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, + decode=cfg.decode, + )(cfg.attention_temp * x, x, mask=decoder_mask) + + x = Dropout(rate=dropout_rate)( + x, deterministic=cfg.deterministic, rate=dropout_rate + ) x = x + targets if not pre_ln: x = nn.LayerNorm(dtype=cfg.dtype)(x) @@ -301,25 +301,27 @@ def __call__(self, # Encoder-Decoder block. y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x y = nn.MultiHeadDotProductAttention( - num_heads=cfg.num_heads, - dtype=cfg.dtype, - qkv_features=cfg.qkv_dim, - kernel_init=cfg.kernel_init, - bias_init=cfg.bias_init, - use_bias=False, - broadcast_dropout=False, - dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)( - cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) - - y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) + num_heads=cfg.num_heads, + dtype=cfg.dtype, + qkv_features=cfg.qkv_dim, + kernel_init=cfg.kernel_init, + bias_init=cfg.bias_init, + use_bias=False, + broadcast_dropout=False, + dropout_rate=dropout_rate, + deterministic=cfg.deterministic, + )(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) + + y = Dropout(rate=dropout_rate)( + y, deterministic=cfg.deterministic, rate=dropout_rate + ) y = y + x if not pre_ln: y = nn.LayerNorm(dtype=cfg.dtype)(y) # MLP block. z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y - z = MlpBlock(config=cfg)(z) + z = MlpBlock(config=cfg)(z, dropout_rate=dropout_rate) return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z) @@ -331,11 +333,18 @@ class Encoder(nn.Module): config: TransformerConfig dataclass containing hyperparameters. shared_embedding: a shared embedding layer to use. 
""" + config: TransformerConfig shared_embedding: Any = None @nn.compact - def __call__(self, inputs, inputs_positions=None, encoder_mask=None): + def __call__( + self, + inputs, + inputs_positions=None, + encoder_mask=None, + dropout_rate=DROPOUT_RATE, + ): """Applies Transformer model on the inputs. Args: @@ -347,37 +356,40 @@ def __call__(self, inputs, inputs_positions=None, encoder_mask=None): output of a transformer encoder. """ cfg = self.config + assert inputs.ndim == 2 # (batch, len) # Input Embedding if self.shared_embedding is None: input_embed = nn.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0)) + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0), + ) else: input_embed = self.shared_embedding x = inputs.astype('int32') x = input_embed(x) - x = AddPositionEmbs( - config=cfg, decode=False, name='posembed_input')( - x, inputs_positions=inputs_positions) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic) + x = AddPositionEmbs(config=cfg, decode=False, name='posembed_input')( + x, inputs_positions=inputs_positions + ) + x = Dropout(rate=dropout_rate)( + x, deterministic=cfg.deterministic, rate=dropout_rate + ) x = x.astype(cfg.dtype) # Input Encoder for lyr in range(cfg.num_layers): - x = Encoder1DBlock( - config=cfg, name=f'encoderblock_{lyr}')(x, encoder_mask) + x = Encoder1DBlock(config=cfg, name=f'encoderblock_{lyr}')( + x, encoder_mask, dropout_rate + ) encoded = ( - nn.LayerNorm(dtype=cfg.dtype, name='encoder_layernorm')(x) - if cfg.pre_ln else x) + nn.LayerNorm(dtype=cfg.dtype, name='encoder_layernorm')(x) + if cfg.pre_ln + else x + ) return encoded @@ -389,16 +401,20 @@ class Decoder(nn.Module): config: TransformerConfig dataclass containing hyperparameters. shared_embedding: a shared embedding layer to use. """ + config: TransformerConfig shared_embedding: Any = None @nn.compact - def __call__(self, - encoded, - targets, - targets_positions=None, - decoder_mask=None, - encoder_decoder_mask=None): + def __call__( + self, + encoded, + targets, + targets_positions=None, + decoder_mask=None, + encoder_decoder_mask=None, + dropout_rate=DROPOUT_RATE, + ): """Applies Transformer model on the inputs. 
Args: @@ -419,9 +435,10 @@ def __call__(self, # Target Embedding if self.shared_embedding is None: output_embed = nn.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0)) + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0), + ) else: output_embed = self.shared_embedding @@ -429,28 +446,29 @@ def __call__(self, if not cfg.decode: y = shift_right(y) y = output_embed(y) - y = AddPositionEmbs( - config=cfg, decode=cfg.decode, name='posembed_output')( - y, inputs_positions=targets_positions) - if cfg.dropout_rate is None: - dropout_rate = 0.1 - else: - dropout_rate = cfg.dropout_rate - y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) + y = AddPositionEmbs(config=cfg, decode=cfg.decode, name='posembed_output')( + y, inputs_positions=targets_positions + ) + y = Dropout(rate=dropout_rate)( + y, deterministic=cfg.deterministic, rate=dropout_rate + ) y = y.astype(cfg.dtype) # Target-Input Decoder for lyr in range(cfg.num_layers): - y = EncoderDecoder1DBlock( - config=cfg, name=f'encoderdecoderblock_{lyr}')( - y, - encoded, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask) + y = EncoderDecoder1DBlock(config=cfg, name=f'encoderdecoderblock_{lyr}')( + y, + encoded, + decoder_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask, + dropout_rate=dropout_rate, + ) y = ( - nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_layernorm')(y) - if cfg.pre_ln else y) + nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_layernorm')(y) + if cfg.pre_ln + else y + ) # Use the transpose of embedding matrix for logit transform. logits = output_embed.attend(y.astype(jnp.float32)) @@ -465,6 +483,7 @@ class Transformer(nn.Module): Attributes: config: TransformerConfig dataclass containing hyperparameters. """ + config: TransformerConfig def setup(self): @@ -473,18 +492,26 @@ def setup(self): if cfg.share_embeddings: if cfg.vocab_size is not None: assert cfg.vocab_size == cfg.vocab_size, ( - "can't share embedding with different vocab sizes.") + "can't share embedding with different vocab sizes." + ) self.shared_embedding = nn.Embed( - num_embeddings=cfg.vocab_size, - features=cfg.emb_dim, - embedding_init=nn.initializers.normal(stddev=1.0)) + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + embedding_init=nn.initializers.normal(stddev=1.0), + ) else: self.shared_embedding = None self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding) self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding) - def encode(self, inputs, inputs_positions=None, inputs_segmentation=None): + def encode( + self, + inputs, + inputs_positions=None, + inputs_segmentation=None, + dropout_rate=DROPOUT_RATE, + ): """Applies Transformer encoder-branch on the inputs. Args: @@ -498,27 +525,33 @@ def encode(self, inputs, inputs_positions=None, inputs_segmentation=None): cfg = self.config # Make padding attention mask. encoder_mask = nn.make_attention_mask( - inputs > 0, inputs > 0, dtype=cfg.dtype) + inputs > 0, inputs > 0, dtype=cfg.dtype + ) # Add segmentation block-diagonal attention mask if using segmented data. 
if inputs_segmentation is not None: encoder_mask = nn.combine_masks( - encoder_mask, - nn.make_attention_mask( - inputs_segmentation, - inputs_segmentation, - jnp.equal, - dtype=cfg.dtype)) + encoder_mask, + nn.make_attention_mask( + inputs_segmentation, inputs_segmentation, jnp.equal, dtype=cfg.dtype + ), + ) return self.encoder( - inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask) + inputs, + inputs_positions=inputs_positions, + encoder_mask=encoder_mask, + dropout_rate=dropout_rate, + ) def decode( - self, - encoded, - inputs, # only needed for masks - targets, - targets_positions=None, - inputs_segmentation=None, - targets_segmentation=None): + self, + encoded, + inputs, # only needed for masks + targets, + targets_positions=None, + inputs_segmentation=None, + targets_segmentation=None, + dropout_rate=DROPOUT_RATE, + ): """Applies Transformer decoder-branch on encoded-input and target. Args: @@ -538,45 +571,51 @@ def decode( if cfg.decode: decoder_mask = None encoder_decoder_mask = nn.make_attention_mask( - jnp.ones_like(targets) > 0, inputs > 0, dtype=cfg.dtype) + jnp.ones_like(targets) > 0, inputs > 0, dtype=cfg.dtype + ) else: decoder_mask = nn.combine_masks( - nn.make_attention_mask(targets > 0, targets > 0, dtype=cfg.dtype), - nn.make_causal_mask(targets, dtype=cfg.dtype)) + nn.make_attention_mask(targets > 0, targets > 0, dtype=cfg.dtype), + nn.make_causal_mask(targets, dtype=cfg.dtype), + ) encoder_decoder_mask = nn.make_attention_mask( - targets > 0, inputs > 0, dtype=cfg.dtype) + targets > 0, inputs > 0, dtype=cfg.dtype + ) # Add segmentation block-diagonal attention masks if using segmented data. if inputs_segmentation is not None: decoder_mask = nn.combine_masks( - decoder_mask, - nn.make_attention_mask( - targets_segmentation, - targets_segmentation, - jnp.equal, - dtype=cfg.dtype)) + decoder_mask, + nn.make_attention_mask( + targets_segmentation, targets_segmentation, jnp.equal, dtype=cfg.dtype + ), + ) encoder_decoder_mask = nn.combine_masks( - encoder_decoder_mask, - nn.make_attention_mask( - targets_segmentation, - inputs_segmentation, - jnp.equal, - dtype=cfg.dtype)) + encoder_decoder_mask, + nn.make_attention_mask( + targets_segmentation, inputs_segmentation, jnp.equal, dtype=cfg.dtype + ), + ) logits = self.decoder( - encoded, - targets, - targets_positions=targets_positions, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask) + encoded, + targets, + targets_positions=targets_positions, + decoder_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask, + dropout_rate=dropout_rate, + ) return logits.astype(self.config.dtype) - def __call__(self, - inputs, - targets, - inputs_positions=None, - targets_positions=None, - inputs_segmentation=None, - targets_segmentation=None): + def __call__( + self, + inputs, + targets, + inputs_positions=None, + targets_positions=None, + inputs_segmentation=None, + targets_segmentation=None, + dropout_rate=DROPOUT_RATE, + ): """Applies Transformer model on the inputs. Args: @@ -591,14 +630,18 @@ def __call__(self, logits array from full transformer. 
""" encoded = self.encode( - inputs, - inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation) + inputs, + inputs_positions=inputs_positions, + inputs_segmentation=inputs_segmentation, + dropout_rate=dropout_rate, + ) return self.decode( - encoded, - inputs, # only used for masks - targets, - targets_positions=targets_positions, - inputs_segmentation=inputs_segmentation, - targets_segmentation=targets_segmentation) + encoded, + inputs, # only used for masks + targets, + targets_positions=targets_positions, + inputs_segmentation=inputs_segmentation, + targets_segmentation=targets_segmentation, + dropout_rate=dropout_rate, + ) diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index cdfcb91df..51d8a85a7 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -1,23 +1,21 @@ """WMT workload implemented in Jax.""" -from dataclasses import replace import functools +from dataclasses import replace from typing import Any, Dict, Iterator, Optional, Tuple -from absl import logging -from flax import jax_utils -from flax import linen as nn -from flax.training import common_utils import jax import jax.numpy as jnp import numpy as np import optax +from absl import logging +from flax import jax_utils +from flax import linen as nn +from flax.training import common_utils -from algoperf import param_utils -from algoperf import spec +from algoperf import param_utils, spec from algoperf.workloads.wmt import bleu -from algoperf.workloads.wmt.wmt_jax import decode -from algoperf.workloads.wmt.wmt_jax import models +from algoperf.workloads.wmt.wmt_jax import decode, models from algoperf.workloads.wmt.workload import BaseWmtWorkload @@ -31,11 +29,12 @@ class WmtWorkload(BaseWmtWorkload): """WMT Jax workload.""" def compute_weighted_cross_entropy( - self, - logits: spec.Tensor, - targets: spec.Tensor, - weights: Optional[spec.Tensor] = None, - label_smoothing: float = 0.1) -> Dict[str, spec.Tensor]: # differentiable + self, + logits: spec.Tensor, + targets: spec.Tensor, + weights: Optional[spec.Tensor] = None, + label_smoothing: float = 0.1, + ) -> Dict[str, spec.Tensor]: # differentiable """Compute weighted cross entropy and entropy for log probs and targets. Args: @@ -50,76 +49,86 @@ def compute_weighted_cross_entropy( valid examples in batch, 'per_example': 1-d array of per-example losses} """ if logits.ndim != targets.ndim + 1: - raise ValueError(f'Incorrect shapes. Got shape {logits.shape} logits and ' - f'{targets.shape} targets.') + raise ValueError( + f'Incorrect shapes. Got shape {logits.shape} logits and ' + f'{targets.shape} targets.' + ) smoothed_targets = optax.smooth_labels( - common_utils.onehot(targets, self._vocab_size), label_smoothing) + common_utils.onehot(targets, self._vocab_size), label_smoothing + ) per_example_losses = -jnp.sum( - smoothed_targets * nn.log_softmax(logits), axis=-1) + smoothed_targets * nn.log_softmax(logits), axis=-1 + ) if weights is None: weights = jnp.ones_like(targets) - per_example_losses = jnp.where(weights, per_example_losses, 0.) 
+ per_example_losses = jnp.where(weights, per_example_losses, 0.0) summed_loss = per_example_losses.sum() n_valid_examples = weights.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)) + jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,) + ) def eval_step_pmapped( - self, params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: + self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor] + ) -> Dict[str, spec.Tensor]: """Calculate evaluation metrics on a batch.""" inputs = batch['inputs'] targets = batch['targets'] weights = batch['weights'] logits = self._eval_model.apply({'params': params}, inputs, targets) - summed_loss = self.compute_weighted_cross_entropy(logits, - targets, - weights, - 0.0)['summed'] + summed_loss = self.compute_weighted_cross_entropy( + logits, targets, weights, 0.0 + )['summed'] acc_sum, weight_sum = self.compute_weighted_accuracy( - logits, targets, weights) + logits, targets, weights + ) return { - 'loss': summed_loss, - 'accuracy': acc_sum, - 'denominator': weight_sum, + 'loss': summed_loss, + 'accuracy': acc_sum, + 'denominator': weight_sum, } - def eval_step(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: + def eval_step( + self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor] + ) -> Dict[str, spec.Tensor]: replicated_eval_metrics = self.eval_step_pmapped(params, batch) return jax.tree.map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics) @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)) - def initialize_cache(self, - inputs: spec.Tensor, - max_decode_len: int = 256) -> Dict[str, spec.Tensor]: + jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,) + ) + def initialize_cache( + self, inputs: spec.Tensor, max_decode_len: int = 256 + ) -> Dict[str, spec.Tensor]: """Initialize a cache for a given input shape and max decode length.""" config = models.TransformerConfig(deterministic=True, decode=True) target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:] initial_variables = models.Transformer(config).init( - jax.random.PRNGKey(0), - jnp.ones(inputs.shape, jnp.float32), - jnp.ones(target_shape, jnp.float32)) + jax.random.PRNGKey(0), + jnp.ones(inputs.shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + ) return initial_variables['cache'] # eos_id, max_decode_len are constant. @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0, 4, 5)) - def predict_step(self, - inputs: spec.Tensor, - params: spec.ParameterContainer, - cache: Dict[str, spec.Tensor], - eos_id: int, - max_decode_len: int, - beam_size: int = 4) -> spec.Tensor: + jax.pmap, axis_name='batch', static_broadcasted_argnums=(0, 4, 5) + ) + def predict_step( + self, + inputs: spec.Tensor, + params: spec.ParameterContainer, + cache: Dict[str, spec.Tensor], + eos_id: int, + max_decode_len: int, + beam_size: int = 4, + ) -> spec.Tensor: """Predict translation with fast decoding beam search on a batch.""" config = replace(self._eval_model.config, decode=True) # Prepare transformer fast-decoder call for beam search: for beam search, we @@ -129,27 +138,29 @@ def predict_step(self, # i.e. 
if we denote each batch element subtensor as el[n]: # [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2] encoded_inputs = decode.flat_batch_beam_expand( - models.Transformer(config).apply({'params': params}, - inputs, - method=models.Transformer.encode), - beam_size) + models.Transformer(config).apply( + {'params': params}, inputs, method=models.Transformer.encode + ), + beam_size, + ) raw_inputs = decode.flat_batch_beam_expand(inputs, beam_size) def tokens_ids_to_logits( - flat_ids: spec.Tensor, flat_cache: Dict[str, spec.Tensor] + flat_ids: spec.Tensor, flat_cache: Dict[str, spec.Tensor] ) -> Tuple[spec.Tensor, Dict[str, spec.Tensor]]: """Token slice to logits from decoder model.""" # --> [batch * beam, 1, vocab] flat_logits, new_vars = models.Transformer(config).apply( - { - 'params': params, - 'cache': flat_cache, - }, - encoded_inputs, - raw_inputs, # only needed for input padding mask - flat_ids, - mutable=['cache'], - method=models.Transformer.decode) + { + 'params': params, + 'cache': flat_cache, + }, + encoded_inputs, + raw_inputs, # only needed for input padding mask + flat_ids, + mutable=['cache'], + method=models.Transformer.decode, + ) new_flat_cache = new_vars['cache'] # Remove singleton sequence-length dimension: # [batch * beam, 1, vocab] --> [batch * beam, vocab] @@ -159,35 +170,36 @@ def tokens_ids_to_logits( # Using the above-defined single-step decoder function, run a # beam search over possible sequences given input encoding. beam_seqs, _ = decode.beam_search( - inputs, - cache, - tokens_ids_to_logits, - beam_size=beam_size, - alpha=0.6, - eos_id=eos_id, - max_decode_len=max_decode_len) + inputs, + cache, + tokens_ids_to_logits, + beam_size=beam_size, + alpha=0.6, + eos_id=eos_id, + max_decode_len=max_decode_len, + ) # Beam search returns [n_batch, n_beam, n_length + 1] with beam dimension # sorted in increasing order of log-probability. # Return the highest scoring beam sequence, drop first dummy 0 token. return beam_seqs[:, -1, 1:] - def translate_and_calculate_bleu(self, - params: spec.ParameterContainer, - ds_iter: Iterator, - num_batches: int, - max_predict_length: int) -> spec.Tensor: + def translate_and_calculate_bleu( + self, + params: spec.ParameterContainer, + ds_iter: Iterator, + num_batches: int, + max_predict_length: int, + ) -> spec.Tensor: """Translates the `predict_ds` and calculates the BLEU score.""" logging.info('Translating evaluation dataset.') references, predictions = [], [] for _ in range(num_batches): pred_batch = next(ds_iter) cache = self.initialize_cache(pred_batch['inputs']) - predicted = self.predict_step(pred_batch['inputs'], - params, - cache, - decode.EOS_ID, - max_predict_length) + predicted = self.predict_step( + pred_batch['inputs'], params, cache, decode.EOS_ID, max_predict_length + ) predicted = _to_host(predicted) targets = _to_host(pred_batch['targets']) # Find actual batch size, ignoring the potential padding. 
@@ -206,13 +218,7 @@ def translate_and_calculate_bleu(self, bleu_score = bleu.corpus_bleu(predictions, [references]).score return bleu_score - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is used as attention_dropout_rate.""" - + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: init_fake_batch_size = 2 input_shape = (init_fake_batch_size, 256) target_shape = (init_fake_batch_size, 256) @@ -225,20 +231,20 @@ def init_model_fn( raise ValueError(f'Unknown activation function {self.activation}.') model_config = models.TransformerConfig( - dropout_rate=dropout_rate, - attention_dropout_rate=aux_dropout_rate, - pre_ln=self.pre_ln, - attention_temp=self.attention_temp, - activation=activation, - glu=self.glu) + pre_ln=self.pre_ln, + attention_temp=self.attention_temp, + activation=activation, + glu=self.glu, + ) self._train_model = models.Transformer(model_config) eval_config = replace(model_config, deterministic=True) self._eval_model = models.Transformer(eval_config) - params_rng, dropout_rng = jax.random.split(rng) - initial_variables = jax.jit( - self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng}, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32)) + params_rng, _ = jax.random.split(rng) + initial_variables = jax.jit(self._eval_model.init)( + {'params': params_rng}, + jnp.ones(input_shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + ) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) @@ -249,45 +255,54 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'shared_embedding' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm inputs = augmented_and_preprocessed_input_batch.get('inputs', None) targets = augmented_and_preprocessed_input_batch.get('targets', None) inputs_positions = augmented_and_preprocessed_input_batch.get( - 'inputs_position', None) + 'inputs_position', None + ) targets_positions = augmented_and_preprocessed_input_batch.get( - 'targets_position', None) + 'targets_position', None + ) inputs_segmentations = augmented_and_preprocessed_input_batch.get( - 'inputs_segmentation', None) + 'inputs_segmentation', None + ) targets_segmentations = augmented_and_preprocessed_input_batch.get( - 'targets_segmentation', None) + 'targets_segmentation', None + ) if mode == spec.ForwardPassMode.TRAIN: model = self._train_model else: model = self._eval_model - logits_batch = model.apply({'params': params}, - inputs, - targets, - inputs_positions=inputs_positions, - targets_positions=targets_positions, - inputs_segmentation=inputs_segmentations, - targets_segmentation=targets_segmentations, - rngs={'dropout': rng}) + logits_batch = model.apply( + {'params': params}, + inputs, + targets, +
inputs_positions=inputs_positions, + targets_positions=targets_positions, + inputs_segmentation=inputs_segmentations, + targets_segmentation=targets_segmentations, + rngs={'dropout': rng}, + dropout_rate=dropout_rate, + ) return logits_batch, None def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" del num_examples eval_denominator = total_metrics.pop('denominator') diff --git a/algoperf/workloads/wmt/wmt_pytorch/decode.py b/algoperf/workloads/wmt/wmt_pytorch/decode.py index 26ff36650..7974412d7 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/decode.py +++ b/algoperf/workloads/wmt/wmt_pytorch/decode.py @@ -21,8 +21,9 @@ NEG_INF = torch.tensor(-1.0e7, device=DEVICE) -def brevity_penalty(alpha: float, length: Union[int, - torch.Tensor]) -> torch.Tensor: +def brevity_penalty( + alpha: float, length: Union[int, torch.Tensor] +) -> torch.Tensor: """Brevity penalty function for beam search penalizing short sequences. Args: @@ -57,8 +58,9 @@ def flatten_beam_dim(x: torch.Tensor) -> torch.Tensor: return x.view(-1, *x.shape[2:]) -def unflatten_beam_dim(x: torch.Tensor, batch_size: int, - beam_size: int) -> torch.Tensor: +def unflatten_beam_dim( + x: torch.Tensor, batch_size: int, beam_size: int +) -> torch.Tensor: """Unflattens the first, flat batch*beam dimension of a non-scalar tensor.""" if x.dim() < 2: # ignore scalars (e.g. cache index) return x @@ -71,10 +73,12 @@ def flat_batch_beam_expand(x: torch.Tensor, beam_size: int) -> torch.Tensor: return flatten_beam_dim(add_beam_dim(x, beam_size)) -def gather_beams(nested: Dict[str, Any], - beam_indices: torch.Tensor, - batch_size: int, - new_beam_size: int) -> Dict[str, Any]: +def gather_beams( + nested: Dict[str, Any], + beam_indices: torch.Tensor, + batch_size: int, + new_beam_size: int, +) -> Dict[str, Any]: """Gathers the beam slices indexed by beam_indices into new beam tensor. Args: @@ -88,10 +92,13 @@ def gather_beams(nested: Dict[str, Any], [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...] """ batch_indices = torch.reshape( - torch.div( - torch.arange(batch_size * new_beam_size, device=DEVICE), - new_beam_size, - rounding_mode='floor'), (batch_size, new_beam_size)) + torch.div( + torch.arange(batch_size * new_beam_size, device=DEVICE), + new_beam_size, + rounding_mode='floor', + ), + (batch_size, new_beam_size), + ) def gather_fn(x): if x.dim() < 2: # ignore scalars (e.g. cache index) @@ -101,10 +108,12 @@ def gather_fn(x): return jax.tree.map(gather_fn, nested) -def gather_topk_beams(nested: Dict[str, Any], - score_or_log_prob: torch.Tensor, - batch_size: int, - new_beam_size: int) -> Dict[str, Any]: +def gather_topk_beams( + nested: Dict[str, Any], + score_or_log_prob: torch.Tensor, + batch_size: int, + new_beam_size: int, +) -> Dict[str, Any]: """Gathers the top-k beam slices given by score_or_log_prob array. Args: @@ -129,6 +138,7 @@ def gather_topk_beams(nested: Dict[str, Any], @dataclass class BeamState: """Holds beam search state data.""" + # The position of the decoding loop in the length dimension. cur_index: torch.Tensor # scalar int32: current decoded length index. # The active sequence log probabilities and finished sequence scores. @@ -143,49 +153,52 @@ class BeamState: cache: Dict[str, Any] # Any dict (of dicts), with torch.Tensors as leafs. 
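# Illustrative sketch (not part of the upstream patch): how the beam-dimension
# helpers defined above are expected to compose, assuming `add_beam_dim` tiles
# a new beam axis as its docstring states. The toy shapes below are
# hypothetical and chosen only to make the bookkeeping visible.
import torch

batch, beam, length = 3, 2, 5
x = torch.arange(batch * length).view(batch, length)  # [batch, length]
# Each batch element is repeated `beam` times, i.e.
# [el0, el1, el2] --> beam=2 --> [el0, el0, el1, el1, el2, el2].
expanded = flat_batch_beam_expand(x, beam)  # [batch * beam, length]
assert expanded.shape == (batch * beam, length)
restored = unflatten_beam_dim(expanded, batch, beam)  # [batch, beam, length]
assert restored.shape == (batch, beam, length)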
-def beam_init(batch_size: int, - beam_size: int, - max_decode_len: int, - cache: Dict[str, Any]) -> BeamState: +def beam_init( + batch_size: int, beam_size: int, max_decode_len: int, cache: Dict[str, Any] +) -> BeamState: """Initializes the beam search state data structure.""" cur_index0 = torch.tensor(0, device=DEVICE) live_logprobs0 = torch.tile( - torch.tensor([0.0] + [NEG_INF] * (beam_size - 1), device=DEVICE), - [batch_size, 1]) + torch.tensor([0.0] + [NEG_INF] * (beam_size - 1), device=DEVICE), + [batch_size, 1], + ) finished_scores0 = ( - torch.ones((batch_size, beam_size), device=DEVICE) * NEG_INF) - live_seqs0 = torch.zeros((batch_size, beam_size, max_decode_len), - dtype=torch.int32, - device=DEVICE) - finished_seqs0 = torch.zeros((batch_size, beam_size, max_decode_len), - dtype=torch.int32, - device=DEVICE) - finished_flags0 = torch.zeros((batch_size, beam_size), - dtype=torch.bool, - device=DEVICE) + torch.ones((batch_size, beam_size), device=DEVICE) * NEG_INF + ) + live_seqs0 = torch.zeros( + (batch_size, beam_size, max_decode_len), dtype=torch.int32, device=DEVICE + ) + finished_seqs0 = torch.zeros( + (batch_size, beam_size, max_decode_len), dtype=torch.int32, device=DEVICE + ) + finished_flags0 = torch.zeros( + (batch_size, beam_size), dtype=torch.bool, device=DEVICE + ) # add beam dimension to attention cache pytree elements beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache) return BeamState( - cur_index=cur_index0, - live_logprobs=live_logprobs0, - finished_scores=finished_scores0, - live_seqs=live_seqs0, - finished_seqs=finished_seqs0, - finished_flags=finished_flags0, - cache=beam_cache0) + cur_index=cur_index0, + live_logprobs=live_logprobs0, + finished_scores=finished_scores0, + live_seqs=live_seqs0, + finished_seqs=finished_seqs0, + finished_flags=finished_flags0, + cache=beam_cache0, + ) # Beam search routine: def beam_search( - inputs: torch.Tensor, - cache: Optional[Dict[str, Any]], - tokens_to_logits: Callable, - beam_size: int = 4, - alpha: float = 0.6, - eos_id: int = EOS_ID, - max_decode_len: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + inputs: torch.Tensor, + cache: Optional[Dict[str, Any]], + tokens_to_logits: Callable, + beam_size: int = 4, + alpha: float = 0.6, + eos_id: int = EOS_ID, + max_decode_len: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: """Beam search for transformer machine translation. Args: @@ -211,10 +224,9 @@ def beam_search( end_marker = torch.tensor(eos_id, device=DEVICE) # initialize beam search state - beam_search_init_state = beam_init(batch_size, - beam_size, - max_decode_len, - cache) + beam_search_init_state = beam_init( + batch_size, beam_size, max_decode_len, cache + ) def beam_search_loop_cond_fn(state: BeamState) -> bool: """Beam search loop termination condition.""" @@ -227,11 +239,12 @@ def beam_search_loop_cond_fn(state: BeamState) -> bool: best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty # Get the worst scores from finished sequences. worst_finished_scores, _ = torch.min( - state.finished_scores, dim=1, keepdim=True) + state.finished_scores, dim=1, keepdim=True + ) # Mask out scores from slots without any actual finished sequences. 
- worst_finished_scores = torch.where(state.finished_flags, - worst_finished_scores, - NEG_INF) + worst_finished_scores = torch.where( + state.finished_flags, worst_finished_scores, NEG_INF + ) # If no best possible live score is better than current worst finished # scores, the search cannot improve the finished set further. search_terminated = torch.all(worst_finished_scores > best_live_scores) @@ -248,7 +261,8 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: # --> [batch * beam, 1] cur_index = state.cur_index flat_ids = flatten_beam_dim( - state.live_seqs[:batch_size, :beam_size, cur_index:cur_index + 1]) + state.live_seqs[:batch_size, :beam_size, cur_index : cur_index + 1] + ) # Flatten beam dimension into batch to be compatible with model. # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} flat_cache = jax.tree.map(flatten_beam_dim, state.cache) @@ -263,7 +277,8 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: # Unflatten beam dimension in attention cache arrays # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} new_cache = jax.tree.map( - lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache) + lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache + ) # Gather log probabilities from logits candidate_log_probs = F.log_softmax(logits, dim=-1) @@ -287,13 +302,13 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: topk_log_probs, topk_indices = torch.topk(flat_log_probs, k=beams_to_keep) # Recover the beam index by floor division. topk_beam_indices = torch.div( - topk_indices, vocab_size, rounding_mode='floor') + topk_indices, vocab_size, rounding_mode='floor' + ) # Gather 2*k top beams. # --> [batch, 2*beams, length] - topk_seq = gather_beams(state.live_seqs, - topk_beam_indices, - batch_size, - beams_to_keep) + topk_seq = gather_beams( + state.live_seqs, topk_beam_indices, batch_size, beams_to_keep + ) # Append the most probable 2*K token IDs to the top 2*K sequences # Recover token id by modulo division and expand Id array for broadcasting. @@ -301,11 +316,11 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: topk_ids = torch.unsqueeze(topk_indices % vocab_size, dim=2) # Update sequences for the 2*K top-k new sequences. # --> [batch, 2*beams, length] - topk_seq[:, :, cur_index + 1:] = topk_ids + topk_seq[:, :, cur_index + 1 :] = topk_ids # Update LIVE (in-progress) sequences: # Did any of these sequences reach an end marker? # --> [batch, 2*beams] - newly_finished = (topk_seq[:, :, cur_index + 1] == end_marker) + newly_finished = topk_seq[:, :, cur_index + 1] == end_marker # To prevent these newly finished sequences from being added to the LIVE # set of active beam search sequences, set their log probs to a very large # negative value. @@ -316,22 +331,20 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: new_topk_indices = torch.flip(new_topk_indices, (1,)) # Gather the top k beams (from top 2*k beams). # --> [batch, beams, length], [batch, beams] - top_alive_seq, top_alive_log_probs = gather_beams([topk_seq, new_log_probs], - new_topk_indices, - batch_size, beam_size) + top_alive_seq, top_alive_log_probs = gather_beams( + [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size + ) # Determine the top k beam indices from the original set of all beams. 
# --> [batch, beams] - top_alive_indices = gather_beams(topk_beam_indices, - new_topk_indices, - batch_size, - beam_size) + top_alive_indices = gather_beams( + topk_beam_indices, new_topk_indices, batch_size, beam_size + ) # With these, gather the top k beam-associated caches. # --> {[batch, beams, ...], ...} - top_alive_cache = gather_beams(new_cache, - top_alive_indices, - batch_size, - beam_size) + top_alive_cache = gather_beams( + new_cache, top_alive_indices, batch_size, beam_size + ) # Update FINISHED (reached end of sentence) sequences: # Calculate new seq scores from log probabilities. @@ -344,24 +357,33 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: # new finished sequence scores to existing finished scores and select the # best from the new set of beams. finished_seqs = torch.cat( # --> [batch, 3*beams, length] - [state.finished_seqs, topk_seq], dim=1) + [state.finished_seqs, topk_seq], dim=1 + ) finished_scores = torch.cat( # --> [batch, 3*beams] - [state.finished_scores, new_scores], dim=1) + [state.finished_scores, new_scores], dim=1 + ) finished_flags = torch.cat( # --> [batch, 3*beams] - [state.finished_flags, newly_finished], dim=1) + [state.finished_flags, newly_finished], dim=1 + ) # --> [batch, beams, length], [batch, beams], [batch, beams] top_finished_seq, top_finished_scores, top_finished_flags = ( - gather_topk_beams([finished_seqs, finished_scores, finished_flags], - finished_scores, batch_size, beam_size)) + gather_topk_beams( + [finished_seqs, finished_scores, finished_flags], + finished_scores, + batch_size, + beam_size, + ) + ) return BeamState( - cur_index=cur_index + 1, - live_logprobs=top_alive_log_probs, - finished_scores=top_finished_scores, - live_seqs=top_alive_seq, - finished_seqs=top_finished_seq, - finished_flags=top_finished_flags, - cache=top_alive_cache) + cur_index=cur_index + 1, + live_logprobs=top_alive_log_probs, + finished_scores=top_finished_scores, + live_seqs=top_alive_seq, + finished_seqs=top_finished_seq, + finished_flags=top_finished_flags, + cache=top_alive_cache, + ) state = beam_search_init_state while beam_search_loop_cond_fn(state): @@ -373,12 +395,16 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: # --> [batch] none_finished = torch.any(final_state.finished_flags, dim=1) # --> [batch, beams, length] - finished_seqs = torch.where(none_finished[:, None, None], - final_state.finished_seqs, - final_state.live_seqs) + finished_seqs = torch.where( + none_finished[:, None, None], + final_state.finished_seqs, + final_state.live_seqs, + ) # --> [batch, beams] - finished_scores = torch.where(none_finished[:, None], - final_state.finished_scores, - final_state.live_logprobs) + finished_scores = torch.where( + none_finished[:, None], + final_state.finished_scores, + final_state.live_logprobs, + ) return finished_seqs, finished_scores diff --git a/algoperf/workloads/wmt/wmt_pytorch/models.py b/algoperf/workloads/wmt/wmt_pytorch/models.py index a1c7ce15e..430cc945b 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/models.py +++ b/algoperf/workloads/wmt/wmt_pytorch/models.py @@ -3,11 +3,11 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union import torch -from torch import nn -from torch import Tensor import torch.nn.functional as F -from torch.nn.init import normal_ -from torch.nn.init import xavier_uniform_ +from torch import Tensor, nn +from torch.nn.init import normal_, xavier_uniform_ + +DROPOUT_RATE = 0.1 def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor: @@ -21,7 +21,8 @@ def 
make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor: A `[batch..., len, len]` shaped causal attention mask. """ idxs = torch.broadcast_to( - torch.arange(x.shape[-1], dtype=torch.int32, device=device), x.shape) + torch.arange(x.shape[-1], dtype=torch.int32, device=device), x.shape + ) return torch.greater_equal(idxs.unsqueeze(-1), idxs.unsqueeze(-2)) @@ -31,55 +32,60 @@ def make_src_mask(src, inputs_segmentation, nhead): # Add segmentation block-diagonal attention mask if using segmented data. if inputs_segmentation is not None: src_mask = torch.logical_and( - src_mask, - torch.eq( - inputs_segmentation.unsqueeze(-1), - inputs_segmentation.unsqueeze(-2))) + src_mask, + torch.eq( + inputs_segmentation.unsqueeze(-1), inputs_segmentation.unsqueeze(-2) + ), + ) # Flip values and ensure numerical stability. src_mask = torch.repeat_interleave( - torch.logical_not(src_mask), repeats=nhead, dim=0) + torch.logical_not(src_mask), repeats=nhead, dim=0 + ) new_src_mask = torch.zeros_like(src_mask, dtype=torch.float32) new_src_mask.masked_fill_(src_mask, -1e10) return new_src_mask -def make_tgt_and_memory_mask(tgt, - src, - inputs_segmentation, - targets_segmentation, - decode, - nhead): - """ Utility for creating target and memory mask and adjust them for PyTorch +def make_tgt_and_memory_mask( + tgt, src, inputs_segmentation, targets_segmentation, decode, nhead +): + """Utility for creating target and memory mask and adjust them for PyTorch Transformer API.""" if not decode: tgt_mask = torch.logical_and( - torch.mul((tgt > 0).unsqueeze(-1), (tgt > 0).unsqueeze(-2)), - make_causal_mask(tgt, device=tgt.device)) + torch.mul((tgt > 0).unsqueeze(-1), (tgt > 0).unsqueeze(-2)), + make_causal_mask(tgt, device=tgt.device), + ) memory_mask = torch.mul((tgt > 0).unsqueeze(-1), (src > 0).unsqueeze(-2)) else: tgt_mask = None - memory_mask = torch.mul((torch.ones_like(tgt) > 0).unsqueeze(-1), - (src > 0).unsqueeze(-2)) + memory_mask = torch.mul( + (torch.ones_like(tgt) > 0).unsqueeze(-1), (src > 0).unsqueeze(-2) + ) # Add segmentation block-diagonal attention masks if using segmented data. if inputs_segmentation is not None: tgt_mask = torch.logical_and( - tgt_mask, - torch.eq( - targets_segmentation.unsqueeze(-1), - targets_segmentation.unsqueeze(-2))) + tgt_mask, + torch.eq( + targets_segmentation.unsqueeze(-1), targets_segmentation.unsqueeze(-2) + ), + ) memory_mask = torch.logical_and( - memory_mask, - torch.eq( - targets_segmentation.unsqueeze(-1), - inputs_segmentation.unsqueeze(-2))) + memory_mask, + torch.eq( + targets_segmentation.unsqueeze(-1), inputs_segmentation.unsqueeze(-2) + ), + ) # Flip values and ensure numerical stability. 
memory_mask = torch.repeat_interleave( - torch.logical_not(memory_mask), repeats=nhead, dim=0) + torch.logical_not(memory_mask), repeats=nhead, dim=0 + ) new_memory_mask = torch.zeros_like(memory_mask, dtype=torch.float32) new_memory_mask.masked_fill_(memory_mask, -1e10) if tgt_mask is not None: tgt_mask = torch.repeat_interleave( - torch.logical_not(tgt_mask), repeats=nhead, dim=0) + torch.logical_not(tgt_mask), repeats=nhead, dim=0 + ) new_tgt_mask = torch.zeros_like(tgt_mask, dtype=torch.float32) new_tgt_mask.masked_fill_(tgt_mask, -1e10) tgt_mask = new_tgt_mask @@ -98,48 +104,44 @@ def shift_right(x, axis=1): class Transformer(nn.Module): """Transformer architecture based on the model from the WMT Jax workload.""" - def __init__(self, - ntoken: int = 32000, - d_model: int = 1024, - nhead: int = 16, - d_hid: int = 1024, - nlayers: int = 6, - dropout_rate: Optional[float] = 0.1, - attention_dropout_rate: Optional[float] = 0.1, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - glu: bool = False, - layer_norm_eps: float = 1e-6, - attention_temp: float = 1.0, - pre_ln: bool = True): + def __init__( + self, + ntoken: int = 32000, + d_model: int = 1024, + nhead: int = 16, + d_hid: int = 1024, + nlayers: int = 6, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + attention_temp: float = 1.0, + pre_ln: bool = True, + ): super().__init__() - if dropout_rate is None: - dropout_rate = 0.1 - if attention_dropout_rate is None: - attention_dropout_rate = 0.1 - self.pos_encoder = PositionalEncoding(d_model, dropout_rate) + self.pos_encoder = PositionalEncoding(d_model) self.shared_embedding = nn.Embedding(ntoken, d_model) - self.encoder = Encoder(d_model, - nhead, - d_hid, - nlayers, - dropout_rate, - attention_dropout_rate, - activation, - glu, - layer_norm_eps, - attention_temp, - pre_ln) - self.decoder = Decoder(d_model, - nhead, - d_hid, - nlayers, - dropout_rate, - attention_dropout_rate, - activation, - glu, - layer_norm_eps, - attention_temp, - pre_ln) + self.encoder = Encoder( + d_model, + nhead, + d_hid, + nlayers, + activation, + glu, + layer_norm_eps, + attention_temp, + pre_ln, + ) + self.decoder = Decoder( + d_model, + nhead, + d_hid, + nlayers, + activation, + glu, + layer_norm_eps, + attention_temp, + pre_ln, + ) # Share positional encoding and embedding between encoder and decoder. 
self.encoder.pos_encoder = self.pos_encoder self.encoder.shared_embedding = self.shared_embedding @@ -156,14 +158,17 @@ def _reset_parameters(self): if module.bias is not None: normal_(module.bias, std=1e-6) - def forward(self, - src: Tensor, - tgt: Tensor, - inputs_positions: Optional[Tensor] = None, - targets_positions: Optional[Tensor] = None, - inputs_segmentation: Optional[Tensor] = None, - targets_segmentation: Optional[Tensor] = None, - decode: bool = False) -> Tensor: + def forward( + self, + src: Tensor, + tgt: Tensor, + inputs_positions: Optional[Tensor] = None, + targets_positions: Optional[Tensor] = None, + inputs_segmentation: Optional[Tensor] = None, + targets_segmentation: Optional[Tensor] = None, + decode: bool = False, + dropout_rate: float = DROPOUT_RATE, + ) -> Tensor: """ Args: src: Tensor, shape [batch_size, seq_len] @@ -173,24 +178,30 @@ def forward(self, inputs_segmentation: Optional[Tensor], shape [batch_size, seq_len] targets_segmentation: Optional[Tensor], shape [batch_size, seq_len] decode: bool + dropout_rate: float Returns: output Tensor of shape [batch_size, seq_len, ntoken] """ if src.size(0) != tgt.size(0): raise RuntimeError('The batch size of src and tgt must be equal.') + memory = self.encoder( - src, - inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation) + src, + inputs_positions=inputs_positions, + inputs_segmentation=inputs_segmentation, + dropout_rate=dropout_rate, + ) output = self.decoder( - tgt, - memory, - src, # just for calculating the padding mask - targets_positions=targets_positions, - inputs_segmentation=inputs_segmentation, - targets_segmentation=targets_segmentation, - decode=decode) + tgt, + memory, + src, # just for calculating the padding mask + targets_positions=targets_positions, + inputs_segmentation=inputs_segmentation, + targets_segmentation=targets_segmentation, + decode=decode, + dropout_rate=dropout_rate, + ) return output @@ -213,28 +224,38 @@ class TransformerEncoder(nn.Module): >>> src = torch.rand(10, 32, 512) >>> out = transformer_encoder(src) """ + __constants__ = ['norm'] - def __init__(self, - encoder_layer, - num_layers, - norm=None, - enable_nested_tensor=True, - mask_check=True): + def __init__( + self, + encoder_layer, + num_layers, + norm=None, + enable_nested_tensor=True, + mask_check=True, + ): super().__init__() self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for _ in range(num_layers)]) + [copy.deepcopy(encoder_layer) for _ in range(num_layers)] + ) self.num_layers = num_layers self.norm = norm self.enable_nested_tensor = enable_nested_tensor self.mask_check = mask_check - def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: + def forward( + self, + src: Tensor, + mask: Optional[Tensor] = None, + dropout_rate: Optional[float] = 0.0, + ) -> Tensor: """Pass the input through the encoder layers in turn. Args: src: the sequence to the encoder (required). mask: the mask for the src sequence (optional). + dropout_rate: the dropout probability (optional). Shape: see the docs in Transformer class. @@ -243,10 +264,10 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: convert_to_nested = False for mod in self.layers: - output = mod(output, src_mask=mask) + output = mod(output, src_mask=mask, dropout_rate=dropout_rate) if convert_to_nested: - output = output.to_padded_tensor(0.) 
+ output = output.to_padded_tensor(0.0) if self.norm is not None: output = self.norm(output) @@ -255,109 +276,118 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor: class Encoder(nn.Module): - - def __init__(self, - d_model: int = 1024, - nhead: int = 16, - d_hid: int = 1024, - nlayers: int = 6, - dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - glu: bool = False, - layer_norm_eps: float = 1e-6, - attention_temp: float = 1.0, - pre_ln: bool = True): + def __init__( + self, + d_model: int = 1024, + nhead: int = 16, + d_hid: int = 1024, + nlayers: int = 6, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + attention_temp: float = 1.0, + pre_ln: bool = True, + ): super().__init__() self.nhead = nhead self.shared_embedding = None self.pos_encoder = None encoder_layer = TransformerEncoderLayer( - d_model, - nhead, - d_hid, - dropout_rate, - attention_dropout_rate=attention_dropout_rate, - activation=activation, - glu=glu, - layer_norm_eps=layer_norm_eps, - attention_temp=attention_temp, - pre_ln=pre_ln) - encoder_norm = ( - nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None) + d_model, + nhead, + d_hid, + activation=activation, + glu=glu, + layer_norm_eps=layer_norm_eps, + attention_temp=attention_temp, + pre_ln=pre_ln, + ) + encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None self.encoder = TransformerEncoder(encoder_layer, nlayers, encoder_norm) - def forward(self, - src: Tensor, - inputs_positions: Optional[Tensor] = None, - inputs_segmentation: Optional[Tensor] = None) -> Tensor: + def forward( + self, + src: Tensor, + inputs_positions: Optional[Tensor] = None, + inputs_segmentation: Optional[Tensor] = None, + dropout_rate: Optional[float] = 0.0, + ) -> Tensor: src = src.to(torch.int) src_mask = make_src_mask(src, inputs_segmentation, self.nhead) src = self.shared_embedding(src) - src = self.pos_encoder(src, inputs_positions) - memory = self.encoder(src, mask=src_mask) + src = self.pos_encoder(src, inputs_positions, dropout_rate=dropout_rate) + memory = self.encoder(src, mask=src_mask, dropout_rate=dropout_rate) return memory class Decoder(nn.Module): - - def __init__(self, - d_model: int = 1024, - nhead: int = 16, - d_hid: int = 1024, - nlayers: int = 6, - dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - glu: bool = False, - layer_norm_eps: float = 1e-6, - attention_temp: float = 1.0, - pre_ln: bool = True): + def __init__( + self, + d_model: int = 1024, + nhead: int = 16, + d_hid: int = 1024, + nlayers: int = 6, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + attention_temp: float = 1.0, + pre_ln: bool = True, + ): super().__init__() self.nhead = nhead self.shared_embedding = None self.pos_encoder = None - self.decoder = TransformerDecoder(d_model, - nhead, - d_hid, - dropout_rate, - attention_dropout_rate, - activation, - glu, - layer_norm_eps, - nlayers, - attention_temp, - pre_ln) + self.decoder = TransformerDecoder( + d_model, + nhead, + d_hid, + activation, + glu, + layer_norm_eps, + nlayers, + attention_temp, + pre_ln, + ) def forward( - self, - tgt: Tensor, - memory: Tensor, - src: Tensor, # just for calculating the padding mask - targets_positions: Optional[Tensor] = None, - inputs_segmentation: Optional[Tensor] = None, - 
targets_segmentation: Optional[Tensor] = None, - decode: bool = False, - max_len: Optional[int] = None, - cache: Optional[dict] = None) -> Any: + self, + tgt: Tensor, + memory: Tensor, + src: Tensor, # just for calculating the padding mask + targets_positions: Optional[Tensor] = None, + inputs_segmentation: Optional[Tensor] = None, + targets_segmentation: Optional[Tensor] = None, + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None, + dropout_rate: Optional[float] = 0.0, + ) -> Any: tgt = tgt.to(torch.int) tgt_mask, memory_mask = make_tgt_and_memory_mask( - tgt, src, inputs_segmentation, targets_segmentation, - decode, self.nhead) + tgt, src, inputs_segmentation, targets_segmentation, decode, self.nhead + ) if not decode: tgt = shift_right(tgt) tgt = self.shared_embedding(tgt) - tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache) + tgt = self.pos_encoder( + tgt, + targets_positions, + decode=decode, + cache=cache, + dropout_rate=dropout_rate, + ) if decode: tgt, cache = tgt output = self.decoder( - tgt, - memory, - tgt_mask=tgt_mask, - memory_mask=memory_mask, - decode=decode, - max_len=max_len, - cache=cache) + tgt, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + decode=decode, + max_len=max_len, + cache=cache, + dropout_rate=dropout_rate, + ) if decode: output, cache = output normalize = math.sqrt(output.shape[-1]) @@ -368,28 +398,24 @@ def forward( class PositionalEncoding(nn.Module): - - def __init__(self, - d_model: int, - dropout_rate: float = 0.1, - max_len: int = 256): + def __init__(self, d_model: int, max_len: int = 256): super().__init__() - self.dropout = nn.Dropout(p=dropout_rate) position = torch.arange(max_len).unsqueeze(1) scale_factor = -math.log(10000.0) / (d_model // 2 - 1) div_term = torch.exp(torch.arange(d_model // 2) * scale_factor) pe = torch.zeros(1, max_len, d_model) - pe[0, :, :d_model // 2] = torch.sin(position * div_term) - pe[0, :, d_model // 2:2 * (d_model // 2)] = torch.cos(position * div_term) + pe[0, :, : d_model // 2] = torch.sin(position * div_term) + pe[0, :, d_model // 2 : 2 * (d_model // 2)] = torch.cos(position * div_term) self.register_buffer('pe', pe) def forward( - self, - x: Tensor, - inputs_positions: Optional[Tensor] = None, - decode: bool = False, - cache: Optional[Dict[str, Dict[str, Tensor]]] = None + self, + x: Tensor, + inputs_positions: Optional[Tensor] = None, + decode: bool = False, + cache: Optional[Dict[str, Dict[str, Tensor]]] = None, + dropout_rate: Optional[float] = 0.0, ) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]: """ Args: @@ -397,6 +423,7 @@ def forward( inputs_positions: Tensor (shape [batch_size, seq_len]) or None decode: bool cache: Dict[str, Dict[str, Tensor]] or None + dropout_rate: Optional[float] Returns: Tensor or Tuple[Tensor, Dict[str, Dict[str, Tensor]]] """ @@ -405,21 +432,22 @@ def forward( name = self._get_name() if cache is None: cache = { - name: { - 'cache_index': - torch.tensor(0, dtype=torch.long, device=self.pe.device), - }, + name: { + 'cache_index': torch.tensor( + 0, dtype=torch.long, device=self.pe.device + ), + }, } pe = self.pe[0, cache[name]['cache_index'], :] cache[name]['cache_index'] += 1 - return self.dropout(x + pe), cache + return F.dropout(x + pe, dropout_rate, self.training), cache if inputs_positions is None: # normal unpacked case: - pe = self.pe[:, :x.size(1), :] + pe = self.pe[:, : x.size(1), :] else: # for packed data we need to use known position indices: pe = self.pe[0, inputs_positions, :] - return 
self.dropout(x + pe) + return F.dropout(x + pe, dropout_rate, self.training) # TransformerEncoderLayer and TransformerDecoderLayer are taken from: @@ -438,7 +466,6 @@ class TransformerEncoderLayer(nn.Module): nhead: the number of heads in the multiheadattention models (default=16). dim_feedforward: the dimension of the feedforward network model (default=1024). - dropout_rate: the dropout_rate value (default=0.1). activation: the activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable (default=F.relu). layer_norm_eps: the eps value in layer normalization components @@ -451,81 +478,91 @@ class TransformerEncoderLayer(nn.Module): >>> src = torch.rand(32, 10, 512) >>> out = encoder_layer(src) """ + __constants__ = ['pre_ln'] - def __init__(self, - d_model: int = 1024, - nhead: int = 16, - dim_feedforward: int = 1024, - dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - glu: bool = False, - layer_norm_eps: float = 1e-6, - attention_temp: float = 1.0, - pre_ln: bool = True, - device=None, - dtype=None) -> None: + def __init__( + self, + d_model: int = 1024, + nhead: int = 16, + dim_feedforward: int = 1024, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + attention_temp: float = 1.0, + pre_ln: bool = True, + device=None, + dtype=None, + ) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() self.self_attn = MultiheadAttention( - d_model, - nhead, - self_attn=True, - dropout_rate=attention_dropout_rate, - attention_temp=attention_temp, - bias=False, - **factory_kwargs) + d_model, + nhead, + self_attn=True, + attention_temp=attention_temp, + bias=False, + **factory_kwargs, + ) # Implementation of Feedforward model. self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) self.glu = glu if self.glu: self.linear_glu = nn.Linear(d_model, dim_feedforward, **factory_kwargs) - self.dropout = nn.Dropout(dropout_rate) self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) self.pre_ln = pre_ln self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.dropout1 = nn.Dropout(dropout_rate) - self.dropout2 = nn.Dropout(dropout_rate) self.activation = activation - def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor: + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + dropout_rate: Optional[float] = 0.0, + ) -> Tensor: r"""Pass the input through the encoder layer. Args: src: the sequence to the encoder layer (required). src_mask: the mask for the src sequence (optional). - + dropout_rate: the dropout probability value (optional). Shape: see the docs in Transformer class. 
""" x = src if self.pre_ln: - x = x + self._sa_block(self.norm1(x), src_mask) - x = x + self._ff_block(self.norm2(x)) + x = x + self._sa_block(self.norm1(x), src_mask, dropout_rate) + x = x + self._ff_block(self.norm2(x), dropout_rate) else: - x = self.norm1(x + self._sa_block(x, src_mask)) - x = self.norm2(x + self._ff_block(x)) + x = self.norm1(x + self._sa_block(x, src_mask, dropout_rate)) + x = self.norm2(x + self._ff_block(x, dropout_rate)) return x # Self-attention block: - def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tensor: - x, _ = self.self_attn(x, attn_mask=attn_mask) - return self.dropout1(x) + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + dropout_rate: Optional[float] = 0.0, + ) -> Tensor: + x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate) + return F.dropout(x, dropout_rate, training=self.training) # Feed forward block: - def _ff_block(self, inputs: Tensor) -> Tensor: + def _ff_block( + self, inputs: Tensor, dropout_rate: Optional[float] = 0.0 + ) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: y = self.linear_glu(inputs) x = x * y - x = self.linear2(self.dropout(x)) - return self.dropout2(x) + x = self.linear2(F.dropout(x, dropout_rate, training=self.training)) + return F.dropout(x, dropout_rate, training=self.training) # Modified to use cache for autoregressive decoding and custom @@ -537,7 +574,6 @@ class TransformerDecoder(nn.Module): nhead: the number of heads in the multiheadattention models (default=16) d_hid: the dimension of the feedforward network model (default=1024) - dropout_rate: the dropout_rate value (default=0.1) layer_norm_eps: the eps value in layer normalization components (default=1e-6). decoder_layer: an instance of the TransformerDecoderLayer() class @@ -549,45 +585,51 @@ class TransformerDecoder(nn.Module): >>> tgt = torch.rand(20, 32, 512) >>> out = transformer_decoder(tgt, memory) """ + __constants__ = ['norm'] - def __init__(self, - d_model, - nhead, - d_hid, - dropout_rate, - attention_dropout_rate, - activation, - glu, - layer_norm_eps, - num_layers, - attention_temp, - pre_ln): + def __init__( + self, + d_model, + nhead, + d_hid, + activation, + glu, + layer_norm_eps, + num_layers, + attention_temp, + pre_ln, + ): super().__init__() - self.layers = nn.ModuleList([ + self.layers = nn.ModuleList( + [ TransformerDecoderLayer( - d_model, - nhead, - d_hid, - dropout_rate, - attention_dropout_rate, - activation, - glu, - layer_norm_eps=layer_norm_eps, - attention_temp=attention_temp, - pre_ln=pre_ln) for _ in range(num_layers) - ]) + d_model, + nhead, + d_hid, + activation, + glu, + layer_norm_eps=layer_norm_eps, + attention_temp=attention_temp, + pre_ln=pre_ln, + ) + for _ in range(num_layers) + ] + ) self.num_layers = num_layers - self.norm = (nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None) - - def forward(self, - tgt: Tensor, - memory: Tensor, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - decode: bool = False, - max_len: Optional[int] = None, - cache: Optional[dict] = None) -> Any: + self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None, + dropout_rate: Optional[float] = 0.0, + ) -> Any: r"""Pass the inputs (and mask) through the decoder layer in turn. 
Args: tgt: the sequence to the decoder (required). @@ -596,6 +638,7 @@ def forward(self, memory_mask: the mask for the memory sequence (optional). decode: whether to use cache for autoregressive decoding or not. max_len: maximum sequence length, necessary for decoding cache. + dropout_rate: the dropout probability value (optional) Shape: see the docs in Transformer class. """ @@ -603,14 +646,16 @@ def forward(self, for idx, mod in enumerate(self.layers): output, cache = mod( - output, - memory, - tgt_mask=tgt_mask, - memory_mask=memory_mask, - decode=decode, - max_len=max_len, - cache=cache, - index=idx) + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + decode=decode, + max_len=max_len, + cache=cache, + index=idx, + dropout_rate=dropout_rate, + ) if self.norm is not None: output = self.norm(output) @@ -636,7 +681,6 @@ class TransformerDecoderLayer(nn.Module): nhead: the number of heads in the multiheadattention models (default=16). dim_feedforward: the dimension of the feedforward network model (default=1024). - dropout_rate: the dropout_rate value (default=0.1). activation: the activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable (default=F.relu). layer_norm_eps: the eps value in layer normalization components @@ -650,70 +694,69 @@ class TransformerDecoderLayer(nn.Module): >>> tgt = torch.rand(32, 20, 512) >>> out = decoder_layer(tgt, memory) """ + __constants__ = ['pre_ln'] - def __init__(self, - d_model: int = 1024, - nhead: int = 16, - dim_feedforward: int = 1024, - dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - glu: bool = False, - layer_norm_eps: float = 1e-6, - pre_ln: bool = True, - attention_temp: float = 1.0, - device=None, - dtype=None) -> None: + def __init__( + self, + d_model: int = 1024, + nhead: int = 16, + dim_feedforward: int = 1024, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + glu: bool = False, + layer_norm_eps: float = 1e-6, + pre_ln: bool = True, + attention_temp: float = 1.0, + device=None, + dtype=None, + ) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() self.self_attn = MultiheadAttention( - d_model, - nhead, - self_attn=True, - dropout_rate=attention_dropout_rate, - attention_temp=attention_temp, - bias=False, - **factory_kwargs) + d_model, + nhead, + self_attn=True, + attention_temp=attention_temp, + bias=False, + **factory_kwargs, + ) self.multihead_attn = MultiheadAttention( - d_model, - nhead, - self_attn=False, - dropout_rate=attention_dropout_rate, - attention_temp=attention_temp, - bias=False, - **factory_kwargs) + d_model, + nhead, + self_attn=False, + attention_temp=attention_temp, + bias=False, + **factory_kwargs, + ) # Implementation of Feedforward model. 
self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) self.glu = glu if self.glu: - self.linear_glu = nn.Linear(dim_feedforward, - dim_feedforward, - **factory_kwargs) - self.dropout = nn.Dropout(dropout_rate) + self.linear_glu = nn.Linear( + dim_feedforward, dim_feedforward, **factory_kwargs + ) self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) self.pre_ln = pre_ln self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.dropout1 = nn.Dropout(dropout_rate) - self.dropout2 = nn.Dropout(dropout_rate) - self.dropout3 = nn.Dropout(dropout_rate) self.activation = activation def forward( # pylint: disable=arguments-renamed - self, - tgt: Tensor, - memory: Tensor, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - decode: bool = False, - max_len: Optional[int] = None, - cache: Optional[dict] = None, - index: Optional[int] = None) -> Any: + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None, + index: Optional[int] = None, + dropout_rate: Optional[float] = 0.0, + ) -> Any: r"""Pass the inputs (and mask) through the decoder layer. Args: tgt: the sequence to the decoder layer (required). @@ -722,6 +765,7 @@ def forward( # pylint: disable=arguments-renamed memory_mask: the mask for the memory sequence (optional). decode: wether to use cache for autoregressive decoding or not. max_len: maximum sequence length, necessary for decoding cache. + dropout_rate: the dropout probability value (optional) Shape: see the docs in Transformer class. 
""" @@ -730,61 +774,78 @@ def forward( # pylint: disable=arguments-renamed x = tgt if self.pre_ln: sa_out, cache = self._sa_block( - self.norm1(x), - tgt_mask, - decode=decode, - max_len=max_len, - cache=cache, - index=index) + self.norm1(x), + tgt_mask, + decode=decode, + max_len=max_len, + cache=cache, + index=index, + dropout_rate=dropout_rate, + ) x = x + sa_out - x = x + self._mha_block(self.norm2(x), memory, memory_mask) - x = x + self._ff_block(self.norm3(x)) + x = x + self._mha_block(self.norm2(x), memory, memory_mask, dropout_rate) + x = x + self._ff_block(self.norm3(x), dropout_rate) else: sa_out, cache = self._sa_block( - x, - tgt_mask, - decode=decode, - max_len=max_len, - cache=cache, - index=index) + x, + tgt_mask, + decode=decode, + max_len=max_len, + cache=cache, + index=index, + dropout_rate=dropout_rate, + ) x = self.norm1(x + sa_out) - x = self.norm2(x + self._mha_block(x, memory, memory_mask)) - x = self.norm3(x + self._ff_block(x)) + x = self.norm2(x + self._mha_block(x, memory, memory_mask, dropout_rate)) + x = self.norm3(x + self._ff_block(x, dropout_rate)) return x, cache # Self-attention block: def _sa_block( # pylint: disable=arguments-renamed - self, - x: Tensor, - attn_mask: Optional[Tensor], - decode: bool = False, - max_len: Optional[int] = None, - cache: Optional[dict] = None, - index: Optional[int] = None) -> Any: + self, + x: Tensor, + attn_mask: Optional[Tensor], + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None, + index: Optional[int] = None, + dropout_rate: Optional[float] = 0.0, + ) -> Any: x, cache = self.self_attn( - x, - attn_mask=attn_mask, - decode=decode, - max_len=max_len, - cache=cache, - index=index) - return self.dropout1(x), cache + x, + attn_mask=attn_mask, + decode=decode, + max_len=max_len, + cache=cache, + index=index, + dropout_rate=dropout_rate, + ) + return F.dropout(x, dropout_rate, self.training), cache # Multihead attention block: - def _mha_block(self, x: Tensor, mem: Tensor, - attn_mask: Optional[Tensor]) -> Tensor: - x, _ = self.multihead_attn(x, mem, attn_mask=attn_mask) - return self.dropout2(x) + def _mha_block( + self, + x: Tensor, + mem: Tensor, + attn_mask: Optional[Tensor], + dropout_rate: Optional[float] = 0.0, + ) -> Tensor: + x, _ = self.multihead_attn( + x, mem, attn_mask=attn_mask, dropout_rate=dropout_rate + ) + return F.dropout(x, dropout_rate, self.training) # Feed forward block. - def _ff_block(self, inputs: Tensor) -> Tensor: + def _ff_block( + self, inputs: Tensor, dropout_rate: Optional[float] = 0.0 + ) -> Tensor: x = self.activation(self.linear1(inputs)) if self.glu: y = self.linear_glu(inputs) x = x * y - x = self.linear2(self.dropout(x)) - return self.dropout3(x) + x = self.linear2(F.dropout(x, dropout_rate, self.training)) + return F.dropout(x, dropout_rate, self.training) class MultiheadAttention(nn.Module): @@ -802,8 +863,6 @@ class MultiheadAttention(nn.Module): ``embed_dim // num_heads``). self_attn: Whether self attention or encoder-decoder attention is used. Default: ``True``. - dropout_rate: Dropout probability on ``attn_output_weights``. - Default: ``0.0`` (no dropout_rate). bias: If specified, adds bias to input / output projection layers. Default: ``False``. device: The device of the module. 
@@ -813,35 +872,38 @@ class MultiheadAttention(nn.Module): >>> attn_output, cache = multihead_attn(x) """ - def __init__(self, - embed_dim: int, - num_heads: int, - self_attn: bool = True, - dropout_rate: float = 0., - attention_temp: float = 1.0, - bias: bool = False, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None) -> None: + def __init__( + self, + embed_dim: int, + num_heads: int, + self_attn: bool = True, + attention_temp: float = 1.0, + bias: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.self_attn = self_attn - self.dropout = dropout_rate self.head_dim = embed_dim // num_heads self.attention_temp = attention_temp - assert self.head_dim * num_heads == self.embed_dim, \ - 'embed_dim must be divisible by num_heads.' + assert self.head_dim * num_heads == self.embed_dim, ( + 'embed_dim must be divisible by num_heads.' + ) factory_kwargs = {'device': device, 'dtype': dtype} if self_attn: # Self-attention. self.in_proj = nn.Linear( - embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs + ) else: # Encoder-decoder attention. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) self.kv_proj = nn.Linear( - embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs) + embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs + ) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) self._reset_parameters() @@ -854,14 +916,17 @@ def _reset_parameters(self): if module.bias is not None: normal_(module.bias, std=1e-6) - def forward(self, - x: Tensor, - mem: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - decode: bool = False, - max_len: Optional[int] = None, - cache: Optional[dict] = None, - index: Optional[int] = None) -> Any: + def forward( + self, + x: Tensor, + mem: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + decode: bool = False, + max_len: Optional[int] = None, + cache: Optional[dict] = None, + index: Optional[int] = None, + dropout_rate: Optional[float] = 0.0, + ) -> Any: # TODO: (nico) remove default?! r""" Args: x: Batch of input sequences of shape @@ -869,7 +934,7 @@ def forward(self, attention mechanism. See "Attention Is All You Need" for more details. mem: Batch of input sequences of shape (batch size, sequence length, embedding dimensionality) for - encoder-decoder attention. See "Attention Is All You Need" for more + encoder-decoder attention. See "Attention Is All You Need" for more details. attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape :math:`(L, S)` or @@ -887,6 +952,7 @@ def forward(self, max_len: maximum sequence length, necessary for decoding cache. cache: cache dictionary for autoregressive decoding. index: index of the current decoding step, necessary for decoding cache. + dropout_rate: dropout probability on ``attn_output_weights``. 
Outputs: - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where :math:`L` is the target sequence length, :math:`N` is the batch size, @@ -911,16 +977,13 @@ def forward(self, if decode: if loc_cache is None: loc_cache = { - 'cached_key': - torch.zeros((bsz, max_len, embed_dim), - dtype=k.dtype, - device=k.device), - 'cached_value': - torch.zeros((bsz, max_len, embed_dim), - dtype=v.dtype, - device=v.device), - 'cache_index': - torch.tensor(0, dtype=torch.long, device=k.device), + 'cached_key': torch.zeros( + (bsz, max_len, embed_dim), dtype=k.dtype, device=k.device + ), + 'cached_value': torch.zeros( + (bsz, max_len, embed_dim), dtype=v.dtype, device=v.device + ), + 'cache_index': torch.tensor(0, dtype=torch.long, device=k.device), } cached_key = loc_cache['cached_key'] cached_value = loc_cache['cached_value'] @@ -928,11 +991,13 @@ def forward(self, # Shape check of cached keys against query input. expected_shape = (bsz, 1, embed_dim) if expected_shape != x.shape: - raise ValueError('Autoregressive cache shape error, expected query ' - f'shape {expected_shape} instead got {x.shape}.') + raise ValueError( + 'Autoregressive cache shape error, expected query ' + f'shape {expected_shape} instead got {x.shape}.' + ) # Update key, value caches with our new 1d spatial slices. - cached_key[:, cache_index:cache_index + 1, :] = k - cached_value[:, cache_index:cache_index + 1, :] = v + cached_key[:, cache_index : cache_index + 1, :] = k + cached_value[:, cache_index : cache_index + 1, :] = v k = cached_key v = cached_value cache_index += 1 @@ -942,8 +1007,9 @@ def forward(self, # not the remaining zero elements. if attn_mask is not None: raise ValueError('Attention mask has to be None for decode == True.') - attn_mask = (torch.arange(max_len, device=k.device) >= - cache_index).reshape(1, max_len) + attn_mask = ( + torch.arange(max_len, device=k.device) >= cache_index + ).reshape(1, max_len) # Update sequence length to account for complete sequence. seq_len = k.size(1) @@ -955,17 +1021,21 @@ def forward(self, # Check dtype and shape of attention mask. if not decode and attn_mask is not None: - assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \ - f'Float and bool dtypes are supported, not {attn_mask.dtype}.' + assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, ( + f'Float and bool dtypes are supported, not {attn_mask.dtype}.' + ) # Ensure attn_mask's dim is 3. if attn_mask.dim() == 3: correct_3d_size = (bsz * self.num_heads, tgt_len, seq_len) if attn_mask.shape != correct_3d_size: - raise RuntimeError(f'The shape of attn_mask is {attn_mask.shape}, ' - f'but should be {correct_3d_size}.') + raise RuntimeError( + f'The shape of attn_mask is {attn_mask.shape}, ' + f'but should be {correct_3d_size}.' + ) else: raise RuntimeError( - f"attn_mask's dimension {attn_mask.dim()} is not supported") + f"attn_mask's dimension {attn_mask.dim()} is not supported" + ) # Reshape attention mask to be consistent with q, k, v. attn_mask = attn_mask.view(bsz, self.num_heads, tgt_len, seq_len) @@ -976,15 +1046,17 @@ def forward(self, attn_mask = new_attn_mask # Adjust dropout_rate probability. - dropout_rate = self.dropout if self.training else 0.0 + attn_dropout_rate = dropout_rate if self.training else 0.0 # Calculate attention. q = self.attention_temp * q attn_output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask, dropout_rate) + q, k, v, attn_mask, attn_dropout_rate + ) # Rearrange for output projection. 
- attn_output = attn_output.transpose(1, 2).contiguous().view( - bsz, tgt_len, embed_dim) + attn_output = ( + attn_output.transpose(1, 2).contiguous().view(bsz, tgt_len, embed_dim) + ) # Output projection. attn_output = self.out_proj(attn_output) diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py index d0716d6c8..53d95d393 100644 --- a/algoperf/workloads/wmt/wmt_pytorch/workload.py +++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py @@ -3,20 +3,18 @@ import contextlib from typing import Any, Dict, Optional, Tuple -from absl import logging import jax import tensorflow as tf import torch import torch.distributed as dist -from torch.nn import DataParallel as DP import torch.nn.functional as F +from absl import logging +from torch.nn import DataParallel as DP from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils -from algoperf import pytorch_utils -from algoperf import spec +from algoperf import param_utils, pytorch_utils, spec from algoperf.workloads.wmt import bleu -from algoperf.workloads.wmt.wmt_pytorch import decode +from algoperf.workloads.wmt.wmt_pytorch import decode, models from algoperf.workloads.wmt.wmt_pytorch.models import Transformer from algoperf.workloads.wmt.workload import BaseWmtWorkload @@ -27,11 +25,12 @@ class WmtWorkload(BaseWmtWorkload): """WMT PyTorch workload.""" def compute_weighted_cross_entropy( - self, - logits: spec.Tensor, - targets: spec.Tensor, - weights: Optional[spec.Tensor] = None, - label_smoothing: float = 0.1) -> Dict[str, spec.Tensor]: # differentiable + self, + logits: spec.Tensor, + targets: spec.Tensor, + weights: Optional[spec.Tensor] = None, + label_smoothing: float = 0.1, + ) -> Dict[str, spec.Tensor]: # differentiable """Compute weighted cross entropy and entropy for log probs and targets. Args: @@ -46,11 +45,14 @@ def compute_weighted_cross_entropy( valid examples in batch, 'per_example': 1-d array of per-example losses} """ if logits.ndim != targets.ndim + 1: - raise ValueError(f'Incorrect shapes. Got shape {logits.shape} logits and ' - f'{targets.shape} targets.') + raise ValueError( + f'Incorrect shapes. Got shape {logits.shape} logits and ' + f'{targets.shape} targets.' + ) loss_fn = torch.nn.CrossEntropyLoss( - reduction='none', label_smoothing=label_smoothing) + reduction='none', label_smoothing=label_smoothing + ) if N_GPUS > 1 and not USE_PYTORCH_DDP: loss_fn = DP(loss_fn) @@ -59,24 +61,27 @@ def compute_weighted_cross_entropy( if weights is None: weights = torch.ones_like(targets) per_example_losses = torch.where( - weights.to(torch.bool), per_example_losses, 0.) + weights.to(torch.bool), per_example_losses, 0.0 + ) summed_loss = per_example_losses.sum() n_valid_examples = weights.sum() return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, } # Primary eval / decode step functions. 
# ---------------------------------------------------------------------------- @torch.no_grad() - def predict_step(self, - inputs: spec.Tensor, - params: spec.ParameterContainer, - eos_id: int, - max_decode_len: int, - beam_size: int = 4) -> spec.Tensor: + def predict_step( + self, + inputs: spec.Tensor, + params: spec.ParameterContainer, + eos_id: int, + max_decode_len: int, + beam_size: int = 4, + ) -> spec.Tensor: """Predict translation with fast decoding beam search on a batch.""" # params = params.module if isinstance(params, (DP, DDP)) else params if hasattr(params, 'module'): @@ -85,8 +90,8 @@ def predict_step(self, if hasattr(params, '_modules'): params = params._modules - encoder = params["encoder"] - decoder = params["decoder"] + encoder = params['encoder'] + decoder = params['decoder'] else: encoder = params.encoder decoder = params.decoder @@ -97,21 +102,23 @@ def predict_step(self, decoder = DP(decoder) encoded_inputs = torch.repeat_interleave( - encoder(inputs), repeats=beam_size, dim=0) + encoder(inputs), repeats=beam_size, dim=0 + ) raw_inputs = torch.repeat_interleave(inputs, repeats=beam_size, dim=0) def tokens_ids_to_logits( - flat_ids: spec.Tensor, flat_cache: Dict[str, spec.Tensor] + flat_ids: spec.Tensor, flat_cache: Dict[str, spec.Tensor] ) -> Tuple[spec.Tensor, Dict[str, spec.Tensor]]: """Token slice to logits from decoder model.""" # --> [batch * beam, 1, vocab] flat_logits, new_flat_cache = decoder( - flat_ids, - encoded_inputs, - raw_inputs, - decode=True, - max_len=max_decode_len, - cache=flat_cache) + flat_ids, + encoded_inputs, + raw_inputs, + decode=True, + max_len=max_decode_len, + cache=flat_cache, + ) # Remove singleton sequence-length dimension: # [batch * beam, 1, vocab] --> [batch * beam, vocab] flat_logits = flat_logits.squeeze(dim=1) @@ -120,24 +127,27 @@ def tokens_ids_to_logits( # Using the above-defined single-step decoder function, run a # beam search over possible sequences given input encoding. beam_seqs, _ = decode.beam_search( - inputs, - None, - tokens_ids_to_logits, - beam_size=beam_size, - alpha=0.6, - eos_id=eos_id, - max_decode_len=max_decode_len) + inputs, + None, + tokens_ids_to_logits, + beam_size=beam_size, + alpha=0.6, + eos_id=eos_id, + max_decode_len=max_decode_len, + ) # Beam search returns [n_batch, n_beam, n_length + 1] with beam dimension # sorted in increasing order of log-probability. # Return the highest scoring beam sequence, drop first dummy 0 token. return beam_seqs[:, -1, 1:] - def translate_and_calculate_bleu(self, - params: spec.ParameterContainer, - ds_iter: tf.data.Dataset, - num_batches: int, - max_predict_length: int): + def translate_and_calculate_bleu( + self, + params: spec.ParameterContainer, + ds_iter: tf.data.Dataset, + num_batches: int, + max_predict_length: int, + ): """Translates the `ds_iter` and calculates the BLEU score.""" logging.info('Translating evaluation dataset.') references, predictions = [], [] @@ -145,10 +155,9 @@ def translate_and_calculate_bleu(self, pred_batch = next(ds_iter) inputs = pred_batch['inputs'] targets = pred_batch['targets'] - predicted = self.predict_step(inputs, - params, - decode.EOS_ID, - max_predict_length) + predicted = self.predict_step( + inputs, params, decode.EOS_ID, max_predict_length + ) # Find actual batch size, ignoring the potential padding. 
weights = pred_batch.get('weights') @@ -165,12 +174,7 @@ def translate_and_calculate_bleu(self, bleu_score = bleu.corpus_bleu(predictions, [references]).score return bleu_score - def init_model_fn( - self, - rng: spec.RandomState, - dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: - """aux_dropout_rate is used as attention_dropout_rate.""" + def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) if self.activation == 'relu': @@ -181,12 +185,11 @@ def init_model_fn( raise ValueError(f'Unknown activation function {self.activation}.') model = Transformer( - dropout_rate=dropout_rate, - attention_dropout_rate=aux_dropout_rate, - pre_ln=self.pre_ln, - attention_temp=self.attention_temp, - activation=activation, - glu=self.glu) + pre_ln=self.pre_ln, + attention_temp=self.attention_temp, + activation=activation, + glu=self.glu, + ) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -201,13 +204,15 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'shared_embedding.weight' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + dropout_rate: float = models.DROPOUT_RATE, + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng del update_batch_norm @@ -218,43 +223,53 @@ def model_fn( model.eval() contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): logits_batch = model( - src=augmented_and_preprocessed_input_batch['inputs'], - tgt=augmented_and_preprocessed_input_batch['targets'], - inputs_positions=augmented_and_preprocessed_input_batch.get( - 'inputs_position', None), - targets_positions=augmented_and_preprocessed_input_batch.get( - 'targets_position', None), - inputs_segmentation=augmented_and_preprocessed_input_batch.get( - 'inputs_segmentation', None), - targets_segmentation=augmented_and_preprocessed_input_batch.get( - 'targets_segmentation', None)) + src=augmented_and_preprocessed_input_batch['inputs'], + tgt=augmented_and_preprocessed_input_batch['targets'], + inputs_positions=augmented_and_preprocessed_input_batch.get( + 'inputs_position', None + ), + targets_positions=augmented_and_preprocessed_input_batch.get( + 'targets_position', None + ), + inputs_segmentation=augmented_and_preprocessed_input_batch.get( + 'inputs_segmentation', None + ), + targets_segmentation=augmented_and_preprocessed_input_batch.get( + 'targets_segmentation', None + ), + dropout_rate=dropout_rate, + ) return logits_batch, None - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: 
str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, + ): per_device_batch_size = int(global_batch_size / N_GPUS) n_inputs = 7 if split == 'train' else 3 # The input pipeline has to be created in all processes, because # self._tokenizer has to be available in every process. - np_iter = super()._build_input_queue(data_rng, - split, - data_dir, - global_batch_size, - num_batches, - repeat_final_dataset) + np_iter = super()._build_input_queue( + data_rng, + split, + data_dir, + global_batch_size, + num_batches, + repeat_final_dataset, + ) # We only need np_iter in one Python process. if RANK != 0: del np_iter @@ -269,14 +284,15 @@ def _build_input_queue(self, tensor = torch.as_tensor(value, dtype=torch.int64, device=DEVICE) tensor_list.append(tensor) batch[key] = ( - tensor[0] if USE_PYTORCH_DDP else tensor.view( - -1, value.shape[-1])) + tensor[0] if USE_PYTORCH_DDP else tensor.view(-1, value.shape[-1]) + ) # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: # During eval, the batch size of the remainder might be different. if split != 'train': per_device_batch_size = torch.tensor( - len(batch['inputs']), dtype=torch.int32, device=DEVICE) + len(batch['inputs']), dtype=torch.int32, device=DEVICE + ) dist.broadcast(per_device_batch_size, src=0) # We don't need to broadcast the batch for the device with RANK == 0. dist.broadcast(torch.stack(tensor_list)[:, 1:].contiguous(), src=0) @@ -284,25 +300,27 @@ def _build_input_queue(self, batch = {} # During eval, the batch size of the remainder might be different. if split != 'train': - per_device_batch_size = torch.empty((1,), - dtype=torch.int32, - device=DEVICE) + per_device_batch_size = torch.empty( + (1,), dtype=torch.int32, device=DEVICE + ) dist.broadcast(per_device_batch_size, src=0) # N_GPUS - 1 since we don't broadcast the batch for RANK == 0. - tensor = torch.empty((n_inputs, N_GPUS - 1, per_device_batch_size, 256), - dtype=torch.int64, - device=DEVICE) + tensor = torch.empty( + (n_inputs, N_GPUS - 1, per_device_batch_size, 256), + dtype=torch.int64, + device=DEVICE, + ) dist.broadcast(tensor, src=0) # Note that the order of the keys is important. if split == 'train': keys = [ - 'inputs', - 'inputs_position', - 'inputs_segmentation', - 'targets', - 'targets_position', - 'targets_segmentation', - 'weights', + 'inputs', + 'inputs_position', + 'inputs_segmentation', + 'targets', + 'targets_position', + 'targets_segmentation', + 'weights', ] # For all eval/test splits. 
else: @@ -312,34 +330,35 @@ def _build_input_queue(self, batch[key] = tensor[n][RANK - 1] yield batch - def eval_step(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: + def eval_step( + self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor] + ) -> Dict[str, spec.Tensor]: """Calculate evaluation metrics on a batch.""" targets = batch['targets'] weights = batch['weights'] logits, _ = self.model_fn( - params, - batch, - mode=spec.ForwardPassMode.EVAL, - model_state=None, - rng=None, - update_batch_norm=False) - summed_loss = self.compute_weighted_cross_entropy(logits, - targets, - weights, - 0.0)['summed'] + params, + batch, + mode=spec.ForwardPassMode.EVAL, + model_state=None, + rng=None, + update_batch_norm=False, + ) + summed_loss = self.compute_weighted_cross_entropy( + logits, targets, weights, 0.0 + )['summed'] acc_sum, weight_sum = self.compute_weighted_accuracy( - logits, targets, weights) + logits, targets, weights + ) return { - 'loss': summed_loss, - 'accuracy': acc_sum, - 'denominator': weight_sum, + 'loss': summed_loss, + 'accuracy': acc_sum, + 'denominator': weight_sum, } def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" del num_examples if USE_PYTORCH_DDP: diff --git a/algoperf/workloads/wmt/workload.py b/algoperf/workloads/wmt/workload.py index 51b33373d..40e4262dd 100644 --- a/algoperf/workloads/wmt/workload.py +++ b/algoperf/workloads/wmt/workload.py @@ -60,8 +60,9 @@ def num_eval_train_examples(self) -> int: # Round up from num_validation_examples (which is the default for # num_eval_train_examples) to the next multiple of eval_batch_size, so that # we don't have to extract the correctly sized subset of the training data. - rounded_up_multiple = math.ceil(self.num_validation_examples / - self.eval_batch_size) + rounded_up_multiple = math.ceil( + self.num_validation_examples / self.eval_batch_size + ) return rounded_up_multiple * self.eval_batch_size @property @@ -115,23 +116,26 @@ def activation(self) -> str: def glu(self) -> bool: return False - def _build_input_queue(self, - data_rng: jax.random.PRNGKey, - split: str, - data_dir: str, - global_batch_size: int, - num_batches: Optional[int] = None, - repeat_final_dataset: bool = False): + def _build_input_queue( + self, + data_rng: jax.random.PRNGKey, + split: str, + data_dir: str, + global_batch_size: int, + num_batches: Optional[int] = None, + repeat_final_dataset: bool = False, + ): is_training = split == 'train' ds, self._tokenizer = input_pipeline.get_wmt_dataset( - data_rng, - split, - data_dir, - is_training=is_training, - vocab_size=self._vocab_size, - global_batch_size=global_batch_size, - num_batches=num_batches, - repeat_final_dataset=repeat_final_dataset) + data_rng, + split, + data_dir, + is_training=is_training, + vocab_size=self._vocab_size, + global_batch_size=global_batch_size, + num_batches=num_batches, + repeat_final_dataset=repeat_final_dataset, + ) # Separate function is necessary because the code above has to be executed # when _build_input_queue is called (not when next() is first called on it). 
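As a small worked example of the rounding used for num_eval_train_examples earlier in this file (the concrete sizes below are illustrative assumptions, not values taken from the workload):

import math

# Hypothetical sizes, for illustration only.
num_validation_examples = 3000
eval_batch_size = 256

rounded_up_multiple = math.ceil(num_validation_examples / eval_batch_size)  # 12
num_eval_train_examples = rounded_up_multiple * eval_batch_size  # 3072

assert num_eval_train_examples % eval_batch_size == 0
assert num_eval_train_examples >= num_validation_examples

Rounding up to a full multiple of eval_batch_size avoids having to extract an exactly sized subset of the training data.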
@@ -148,19 +152,21 @@ def _input_queue_generator(): @abc.abstractmethod def _normalize_eval_metrics( - self, num_examples: int, total_metrics: Dict[str, - Any]) -> Dict[str, float]: + self, num_examples: int, total_metrics: Dict[str, Any] + ) -> Dict[str, float]: """Normalize eval metrics.""" - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: + def _eval_model_on_split( + self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0, + ) -> Dict[str, float]: """Run a full evaluation of the model.""" del model_state del global_step @@ -168,12 +174,13 @@ def _eval_model_on_split(self, if split not in self._eval_iters: # These iterators will repeat indefinitely. self._eval_iters[split] = self._build_input_queue( - rng, - split, - data_dir, - global_batch_size, - num_batches, - repeat_final_dataset=True) + rng, + split, + data_dir, + global_batch_size, + num_batches, + repeat_final_dataset=True, + ) eval_metrics = {} for _ in range(num_batches): @@ -186,16 +193,17 @@ def _eval_model_on_split(self, eval_results = self._normalize_eval_metrics(num_examples, eval_metrics) eval_results['bleu'] = self.translate_and_calculate_bleu( - params=params, - ds_iter=self._eval_iters[split], - num_batches=num_batches, - max_predict_length=256) + params=params, + ds_iter=self._eval_iters[split], + num_batches=num_batches, + max_predict_length=256, + ) return eval_results def compute_weighted_accuracy( - self, logits: spec.Tensor, targets: spec.Tensor, - weights: spec.Tensor) -> Tuple[spec.Tensor, spec.Tensor]: + self, logits: spec.Tensor, targets: spec.Tensor, weights: spec.Tensor + ) -> Tuple[spec.Tensor, spec.Tensor]: """Compute weighted accuracy for log probs and targets. Args: @@ -207,8 +215,10 @@ def compute_weighted_accuracy( Tuple of scalar summed accuracy and batch normalizing factor. """ if logits.ndim != targets.ndim + 1: - raise ValueError(f'Incorrect shapes. Got shape {logits.shape} logits and ' - f'{targets.shape} targets.') + raise ValueError( + f'Incorrect shapes. Got shape {logits.shape} logits and ' + f'{targets.shape} targets.' + ) accuracy = (logits.argmax(-1) == targets) * weights normalizing_factor = weights.sum() return accuracy.sum(), normalizing_factor @@ -216,17 +226,18 @@ def compute_weighted_accuracy( def _decode_tokens(self, toks: spec.Tensor) -> spec.Tensor: if isinstance(toks, torch.Tensor): toks = toks.cpu().numpy() - valid_toks = toks[:np.argmax(toks == decode.EOS_ID) + 1].astype(np.int32) + valid_toks = toks[: np.argmax(toks == decode.EOS_ID) + 1].astype(np.int32) return self._tokenizer.detokenize(valid_toks).numpy().decode('utf-8') # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn( - self, - label_batch: spec.Tensor, # Dense or one-hot labels. - logits_batch: spec.Tensor, - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + self, + label_batch: spec.Tensor, # Dense or one-hot labels. + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0, + ) -> Dict[str, spec.Tensor]: # differentiable """Evaluate the (masked) loss function at (label_batch, logits_batch). 
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of @@ -234,7 +245,8 @@ def loss_fn( (not synced across devices). """ return self.compute_weighted_cross_entropy( - logits_batch, - label_batch, - weights=mask_batch, - label_smoothing=label_smoothing) + logits_batch, + label_batch, + weights=mask_batch, + label_smoothing=label_smoothing, + ) diff --git a/algoperf/workloads/workloads.py b/algoperf/workloads/workloads.py index 4712f4e25..4dd4717e9 100644 --- a/algoperf/workloads/workloads.py +++ b/algoperf/workloads/workloads.py @@ -1,5 +1,5 @@ -""" Registry of workload info -""" +"""Registry of workload info""" + import importlib import inspect import os @@ -9,149 +9,151 @@ BASE_WORKLOADS_DIR = 'algoperf/workloads/' WORKLOADS = { - 'cifar': { - 'workload_path': 'cifar/cifar', 'workload_class_name': 'CifarWorkload' - }, - 'criteo1tb': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallWorkload', - }, - 'criteo1tb_test': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallTestWorkload', - }, - 'criteo1tb_layernorm': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload' - }, - 'criteo1tb_embed_init': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallEmbedInitWorkload' - }, - 'criteo1tb_resnet': { - 'workload_path': 'criteo1tb/criteo1tb', - 'workload_class_name': 'Criteo1TbDlrmSmallResNetWorkload' - }, - 'fastmri': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRIWorkload', - }, - 'fastmri_model_size': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRIModelSizeWorkload', - }, - 'fastmri_tanh': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRITanhWorkload', - }, - 'fastmri_layernorm': { - 'workload_path': 'fastmri/fastmri', - 'workload_class_name': 'FastMRILayerNormWorkload', - }, - 'imagenet_resnet': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetWorkload', - }, - 'imagenet_resnet_silu': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetSiLUWorkload', - }, - 'imagenet_resnet_gelu': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetGELUWorkload', - }, - 'imagenet_resnet_large_bn_init': { - 'workload_path': 'imagenet_resnet/imagenet', - 'workload_class_name': 'ImagenetResNetLargeBNScaleWorkload', - }, - 'imagenet_vit': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 'ImagenetVitWorkload', - }, - 'imagenet_vit_glu': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 'ImagenetVitGluWorkload', - }, - 'imagenet_vit_post_ln': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 'ImagenetVitPostLNWorkload', - }, - 'imagenet_vit_map': { - 'workload_path': 'imagenet_vit/imagenet', - 'workload_class_name': 'ImagenetVitMapWorkload', - }, - 'librispeech_conformer': { - 'workload_path': 'librispeech_conformer/librispeech', - 'workload_class_name': 'LibriSpeechConformerWorkload', - }, - 'librispeech_conformer_attention_temperature': { - 'workload_path': - 'librispeech_conformer/librispeech', - 'workload_class_name': - 'LibriSpeechConformerAttentionTemperatureWorkload', - }, - 'librispeech_conformer_layernorm': { - 'workload_path': 'librispeech_conformer/librispeech', - 'workload_class_name': 'LibriSpeechConformerLayerNormWorkload', - }, - 'librispeech_conformer_gelu': { - 
'workload_path': 'librispeech_conformer/librispeech', - 'workload_class_name': 'LibriSpeechConformerGeluWorkload', - }, - 'librispeech_deepspeech': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechWorkload', - }, - 'librispeech_deepspeech_tanh': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechTanhWorkload', - }, - 'librispeech_deepspeech_no_resnet': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechNoResNetWorkload', - }, - 'librispeech_deepspeech_norm_and_spec_aug': { - 'workload_path': 'librispeech_deepspeech/librispeech', - 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', - }, - 'mnist': { - 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload' - }, - 'ogbg': { - 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload' - }, - 'ogbg_gelu': { - 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgGeluWorkload' - }, - 'ogbg_silu': { - 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgSiluWorkload' - }, - 'ogbg_model_size': { - 'workload_path': 'ogbg/ogbg', - 'workload_class_name': 'OgbgModelSizeWorkload' - }, - 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'}, - 'wmt_post_ln': { - 'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkloadPostLN' - }, - 'wmt_attention_temp': { - 'workload_path': 'wmt/wmt', - 'workload_class_name': 'WmtWorkloadAttentionTemp' - }, - 'wmt_glu_tanh': { - 'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkloadGLUTanH' - }, + 'cifar': { + 'workload_path': 'cifar/cifar', + 'workload_class_name': 'CifarWorkload', + }, + 'criteo1tb': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallWorkload', + }, + 'criteo1tb_test': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallTestWorkload', + }, + 'criteo1tb_layernorm': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload', + }, + 'criteo1tb_embed_init': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallEmbedInitWorkload', + }, + 'criteo1tb_resnet': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallResNetWorkload', + }, + 'fastmri': { + 'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRIWorkload', + }, + 'fastmri_model_size': { + 'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRIModelSizeWorkload', + }, + 'fastmri_tanh': { + 'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRITanhWorkload', + }, + 'fastmri_layernorm': { + 'workload_path': 'fastmri/fastmri', + 'workload_class_name': 'FastMRILayerNormWorkload', + }, + 'imagenet_resnet': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetWorkload', + }, + 'imagenet_resnet_silu': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetSiLUWorkload', + }, + 'imagenet_resnet_gelu': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetGELUWorkload', + }, + 'imagenet_resnet_large_bn_init': { + 'workload_path': 'imagenet_resnet/imagenet', + 'workload_class_name': 'ImagenetResNetLargeBNScaleWorkload', + }, + 'imagenet_vit': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitWorkload', + }, + 'imagenet_vit_glu': { + 
'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitGluWorkload', + }, + 'imagenet_vit_post_ln': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitPostLNWorkload', + }, + 'imagenet_vit_map': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitMapWorkload', + }, + 'librispeech_conformer': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerWorkload', + }, + 'librispeech_conformer_attention_temperature': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerAttentionTemperatureWorkload', + }, + 'librispeech_conformer_layernorm': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerLayerNormWorkload', + }, + 'librispeech_conformer_gelu': { + 'workload_path': 'librispeech_conformer/librispeech', + 'workload_class_name': 'LibriSpeechConformerGeluWorkload', + }, + 'librispeech_deepspeech': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechWorkload', + }, + 'librispeech_deepspeech_tanh': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechTanhWorkload', + }, + 'librispeech_deepspeech_no_resnet': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechNoResNetWorkload', + }, + 'librispeech_deepspeech_norm_and_spec_aug': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', + }, + 'mnist': { + 'workload_path': 'mnist/mnist', + 'workload_class_name': 'MnistWorkload', + }, + 'ogbg': {'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload'}, + 'ogbg_gelu': { + 'workload_path': 'ogbg/ogbg', + 'workload_class_name': 'OgbgGeluWorkload', + }, + 'ogbg_silu': { + 'workload_path': 'ogbg/ogbg', + 'workload_class_name': 'OgbgSiluWorkload', + }, + 'ogbg_model_size': { + 'workload_path': 'ogbg/ogbg', + 'workload_class_name': 'OgbgModelSizeWorkload', + }, + 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'}, + 'wmt_post_ln': { + 'workload_path': 'wmt/wmt', + 'workload_class_name': 'WmtWorkloadPostLN', + }, + 'wmt_attention_temp': { + 'workload_path': 'wmt/wmt', + 'workload_class_name': 'WmtWorkloadAttentionTemp', + }, + 'wmt_glu_tanh': { + 'workload_path': 'wmt/wmt', + 'workload_class_name': 'WmtWorkloadGLUTanH', + }, } BASE_WORKLOADS = [ - 'criteo1tb', - 'fastmri', - 'imagenet_resnet', - 'imagenet_vit', - 'librispeech_conformer', - 'librispeech_deepspeech', - 'ogbg', - 'wmt' + 'criteo1tb', + 'fastmri', + 'imagenet_resnet', + 'imagenet_vit', + 'librispeech_conformer', + 'librispeech_deepspeech', + 'ogbg', + 'wmt', ] @@ -171,10 +173,12 @@ def convert_filepath_to_module(path: str): return base.replace('/', '.') -def import_workload(workload_path: str, - workload_class_name: str, - return_class=False, - workload_init_kwargs=None) -> spec.Workload: +def import_workload( + workload_path: str, + workload_class_name: str, + return_class=False, + workload_init_kwargs=None, +) -> spec.Workload: """Import and add the workload to the registry. This importlib loading is nice to have because it allows runners to avoid @@ -206,9 +210,10 @@ def import_workload(workload_path: str, break if workload_class is None: raise ValueError( - f'Could not find member {workload_class_name} in {workload_path}. 
' - 'Make sure the Workload class is spelled correctly and defined in ' - 'the top scope of the module.') + f'Could not find member {workload_class_name} in {workload_path}. ' + 'Make sure the Workload class is spelled correctly and defined in ' + 'the top scope of the module.' + ) if return_class: return workload_class return workload_class(**workload_init_kwargs) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index efe923dbe..e110930cd 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -72,8 +72,7 @@ from torchvision.datasets import CIFAR10 from algoperf.workloads.wmt import tokenizer -from algoperf.workloads.wmt.input_pipeline import \ - normalize_feature_names +from algoperf.workloads.wmt.input_pipeline import normalize_feature_names from datasets import librispeech_preprocess from datasets import librispeech_tokenizer @@ -101,84 +100,96 @@ FASTMRI_TEST_TAR_FILENAME = 'knee_singlecoil_test.tar.xz' flags.DEFINE_boolean( - 'interactive_deletion', - True, - 'If true, user will be prompted before any files are deleted. If false, no ' - 'files will be deleted.') + 'interactive_deletion', + True, + 'If true, user will be prompted before any files are deleted. If false, no ' + 'files will be deleted.', +) flags.DEFINE_boolean( - 'all', - False, - 'Whether or not to download all datasets. If false, can download some ' - 'combination of datasets by setting the individual dataset flags below.') - -flags.DEFINE_boolean('criteo1tb', - False, - 'If --all=false, whether or not to download Criteo 1TB.') -flags.DEFINE_boolean('cifar', - False, - 'If --all=false, whether or not to download CIFAR-10.') -flags.DEFINE_boolean('fastmri', - False, - 'If --all=false, whether or not to download FastMRI.') -flags.DEFINE_boolean('imagenet', - False, - 'If --all=false, whether or not to download Imagenet.') -flags.DEFINE_boolean('librispeech', - False, - 'If --all=false, whether or not to download LibriSpeech.') -flags.DEFINE_boolean('mnist', - False, - 'If --all=false, whether or not to download MNIST.') -flags.DEFINE_boolean('ogbg', - False, - 'If --all=false, whether or not to download OGBG.') -flags.DEFINE_boolean('wmt', - False, - 'If --all=false, whether or not to download WMT.') + 'all', + False, + 'Whether or not to download all datasets. If false, can download some ' + 'combination of datasets by setting the individual dataset flags below.', +) + +flags.DEFINE_boolean( + 'criteo1tb', False, 'If --all=false, whether or not to download Criteo 1TB.' +) +flags.DEFINE_boolean( + 'cifar', False, 'If --all=false, whether or not to download CIFAR-10.' +) +flags.DEFINE_boolean( + 'fastmri', False, 'If --all=false, whether or not to download FastMRI.' +) +flags.DEFINE_boolean( + 'imagenet', False, 'If --all=false, whether or not to download Imagenet.' +) +flags.DEFINE_boolean( + 'librispeech', + False, + 'If --all=false, whether or not to download LibriSpeech.', +) +flags.DEFINE_boolean( + 'mnist', False, 'If --all=false, whether or not to download MNIST.' +) +flags.DEFINE_boolean( + 'ogbg', False, 'If --all=false, whether or not to download OGBG.' +) +flags.DEFINE_boolean( + 'wmt', False, 'If --all=false, whether or not to download WMT.' 
+) flags.DEFINE_string( - 'data_dir', - '~/data', - 'The path to the folder where datasets should be downloaded.') + 'data_dir', + '~/data', + 'The path to the folder where datasets should be downloaded.', +) flags.DEFINE_string( - 'temp_dir', - '/tmp/mlcommons', - 'A local path to a folder where temp files can be downloaded.') + 'temp_dir', + '/tmp/mlcommons', + 'A local path to a folder where temp files can be downloaded.', +) flags.DEFINE_string( - 'imagenet_train_url', - None, - 'Only necessary if you want this script to `wget` the ImageNet train ' - 'split. If not, you can supply the path to --data_dir in ' - 'submission_runner.py.') + 'imagenet_train_url', + None, + 'Only necessary if you want this script to `wget` the ImageNet train ' + 'split. If not, you can supply the path to --data_dir in ' + 'submission_runner.py.', +) flags.DEFINE_string( - 'imagenet_val_url', - None, - 'Only necessary if you want this script to `wget` the ImageNet validation ' - 'split. If not, you can supply the path to --data_dir in ' - 'submission_runner.py.') + 'imagenet_val_url', + None, + 'Only necessary if you want this script to `wget` the ImageNet validation ' + 'split. If not, you can supply the path to --data_dir in ' + 'submission_runner.py.', +) flags.DEFINE_string( - 'fastmri_knee_singlecoil_train_url', - None, - 'Only necessary if you want this script to `wget` the FastMRI train ' - 'split. If not, you can supply the path to --data_dir in ' - 'submission_runner.py.') + 'fastmri_knee_singlecoil_train_url', + None, + 'Only necessary if you want this script to `wget` the FastMRI train ' + 'split. If not, you can supply the path to --data_dir in ' + 'submission_runner.py.', +) flags.DEFINE_string( - 'fastmri_knee_singlecoil_val_url', - None, - 'Only necessary if you want this script to `wget` the FastMRI validation ' - 'split. If not, you can supply the path to --data_dir in ' - 'submission_runner.py.') + 'fastmri_knee_singlecoil_val_url', + None, + 'Only necessary if you want this script to `wget` the FastMRI validation ' + 'split. If not, you can supply the path to --data_dir in ' + 'submission_runner.py.', +) flags.DEFINE_string( - 'fastmri_knee_singlecoil_test_url', - None, - 'Only necessary if you want this script to `wget` the FastMRI test ' - 'split. If not, you can supply the path to --data_dir in ' - 'submission_runner.py.') + 'fastmri_knee_singlecoil_test_url', + None, + 'Only necessary if you want this script to `wget` the FastMRI test ' + 'split. If not, you can supply the path to --data_dir in ' + 'submission_runner.py.', +) flags.DEFINE_integer( - 'num_decompression_threads', - 8, - 'The number of threads to use in parallel when decompressing.') + 'num_decompression_threads', + 8, + 'The number of threads to use in parallel when decompressing.', +) flags.DEFINE_string('framework', None, 'Can be either jax or pytorch.') @@ -186,7 +197,7 @@ FLAGS = flags.FLAGS -os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +os.environ['CUDA_VISIBLE_DEVICES'] = '-1' def _maybe_mkdir(d): @@ -198,8 +209,10 @@ def _maybe_prompt_for_deletion(paths, interactive_deletion): if not interactive_deletion: return files_for_deletion = '\n'.join(paths) - logging.info('\n\n\nWARNING: the following temp files will be DELETED:' - f'\n{files_for_deletion}') + logging.info( + '\n\n\nWARNING: the following temp files will be DELETED:' + f'\n{files_for_deletion}' + ) delete_str = input('Confirm deletion? 
[y/N]: ') if delete_str.lower() == 'y': del_cmd = 'rm ' + ' '.join(f'"{s}"' for s in paths) @@ -225,8 +238,9 @@ def _download_url(url, data_dir, name=None): if os.path.exists(file_path): while True: - overwrite = input('File already exists {}.\n Overwrite? (Y/n)'.format( - file_path)).lower() + overwrite = input( + 'File already exists {}.\n Overwrite? (Y/n)'.format(file_path) + ).lower() if overwrite in ['y', 'n']: break logging.info('Invalid response. Try again.') @@ -240,17 +254,18 @@ def _download_url(url, data_dir, name=None): progress_bar.update(chunk_size_in_mib) f.write(chunk) progress_bar.close() - if (progress_bar.total != 0 and progress_bar.n != progress_bar.total): + if progress_bar.total != 0 and progress_bar.n != progress_bar.total: raise RuntimeError( - ('Download corrupted, size {n} MiB from {url} does not match ' - 'expected size {size} MiB').format( - url=url, n=progress_bar.n, size=progress_bar.total)) + ( + 'Download corrupted, size {n} MiB from {url} does not match ' + 'expected size {size} MiB' + ).format(url=url, n=progress_bar.n, size=progress_bar.total) + ) -def download_criteo1tb(data_dir, - tmp_dir, - num_decompression_threads, - interactive_deletion): +def download_criteo1tb( + data_dir, tmp_dir, num_decompression_threads, interactive_deletion +): criteo_dir = os.path.join(data_dir, 'criteo1tb') tmp_criteo_dir = os.path.join(tmp_dir, 'criteo1tb') _maybe_mkdir(criteo_dir) @@ -258,47 +273,56 @@ def download_criteo1tb(data_dir, # Forked from # https://github.com/iamleot/transferwee/blob/master/transferwee.py. - user_agent = ('Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:102.0) ' - 'Gecko/20100101 Firefox/102.0') + user_agent = ( + 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:102.0) ' + 'Gecko/20100101 Firefox/102.0' + ) criteo_wetransfer_url = ( - 'https://criteo.wetransfer.com/downloads/' - '4bbea9b4a54baddea549d71271a38e2c20230428071257/d4f0d2') + 'https://criteo.wetransfer.com/downloads/' + '4bbea9b4a54baddea549d71271a38e2c20230428071257/d4f0d2' + ) _, _, transfer_id, security_hash = urllib.parse.urlparse( - criteo_wetransfer_url).path.split('/') + criteo_wetransfer_url + ).path.split('/') session = requests.Session() - session.headers.update({ + session.headers.update( + { 'User-Agent': user_agent, 'x-requested-with': 'XMLHttpRequest', - }) + } + ) r = session.get('https://wetransfer.com/') m = re.search('name="csrf-token" content="([^"]+)"', r.text) if m: session.headers.update({'x-csrf-token': m.group(1)}) get_url_request = session.post( - f'https://wetransfer.com/api/v4/transfers/{transfer_id}/download', - json={ - 'intent': 'entire_transfer', - 'security_hash': security_hash, - }) + f'https://wetransfer.com/api/v4/transfers/{transfer_id}/download', + json={ + 'intent': 'entire_transfer', + 'security_hash': security_hash, + }, + ) session.close() download_url = get_url_request.json().get('direct_link') logging.info(f'Downloading ~342GB Criteo 1TB data .zip file:\n{download_url}') download_request = requests.get( # pylint: disable=missing-timeout - download_url, - headers={'User-Agent': user_agent}, - stream=True) + download_url, headers={'User-Agent': user_agent}, stream=True + ) all_days_zip_filepath = os.path.join(tmp_criteo_dir, 'all_days.zip') if not FLAGS.skip_download: download = True if os.path.exists(all_days_zip_filepath): while True: - overwrite = input('File already exists {}.\n Overwrite? (Y/n)'.format( - all_days_zip_filepath)).lower() + overwrite = input( + 'File already exists {}.\n Overwrite? 
(Y/n)'.format( + all_days_zip_filepath + ) + ).lower() if overwrite in ['y', 'n']: break logging.info('Invalid response. Try again.') @@ -324,8 +348,10 @@ def download_criteo1tb(data_dir, input_path = os.path.join(tmp_criteo_dir, f'day_{day}.gz') gz_paths.append(input_path) unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv') - unzip_cmd = (f'pigz -d -c -p{num_decompression_threads} "{input_path}" > ' - f'"{unzipped_path}"') + unzip_cmd = ( + f'pigz -d -c -p{num_decompression_threads} "{input_path}" > ' + f'"{unzipped_path}"' + ) logging.info(f'Running Criteo unzip command for day {day}:\n{unzip_cmd}') processes.append(subprocess.Popen(unzip_cmd, shell=True)) for p in processes: @@ -341,8 +367,7 @@ def download_criteo1tb(data_dir, unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv') unzipped_paths.append(unzipped_path) split_path = os.path.join(criteo_dir, f'day_{day}_') - split_cmd = ('split -a 2 -d -l 5000000 ' - f'"{unzipped_path}" "{split_path}"') + split_cmd = f'split -a 2 -d -l 5000000 "{unzipped_path}" "{split_path}"' logging.info(f'Running Criteo 1TB split command:\n{split_cmd}') batch_processes.append(subprocess.Popen(split_cmd, shell=True)) for p in batch_processes: @@ -362,45 +387,50 @@ def download_cifar(data_dir, framework): def extract_filename_from_url(url, start_str='knee', end_str='.xz'): - """ The url filenames are sometimes couched within a urldefense+aws access id + """The url filenames are sometimes couched within a urldefense+aws access id etc. string. Unfortunately querying the content disposition in requests fails (not provided)... so fast search is done here within the url. - """ + """ failure = -1 start = url.find(start_str) end = url.find(end_str) if failure in (start, end): raise ValueError( - f'Unable to locate filename wrapped in {start_str}--{end_str} in {url}') + f'Unable to locate filename wrapped in {start_str}--{end_str} in {url}' + ) end += len(end_str) # make it inclusive return url[start:end] -def download_fastmri(data_dir, - fastmri_train_url, - fastmri_val_url, - fastmri_test_url): +def download_fastmri( + data_dir, fastmri_train_url, fastmri_val_url, fastmri_test_url +): data_dir = os.path.join(data_dir, 'fastmri') # Download fastmri train dataset knee_train_filename = extract_filename_from_url(fastmri_train_url) logging.info( - 'Downloading fastmri train dataset from {}'.format(fastmri_train_url)) + 'Downloading fastmri train dataset from {}'.format(fastmri_train_url) + ) _download_url( - url=fastmri_train_url, data_dir=data_dir, name=knee_train_filename) + url=fastmri_train_url, data_dir=data_dir, name=knee_train_filename + ) # Download fastmri val dataset knee_val_filename = extract_filename_from_url(fastmri_val_url) logging.info( - 'Downloading fastmri val dataset from {}'.format(fastmri_val_url)) + 'Downloading fastmri val dataset from {}'.format(fastmri_val_url) + ) _download_url(url=fastmri_val_url, data_dir=data_dir, name=knee_val_filename) # Download fastmri test dataset knee_test_filename = extract_filename_from_url(fastmri_test_url) logging.info( - 'Downloading fastmri test dataset from {}'.format(fastmri_test_url)) + 'Downloading fastmri test dataset from {}'.format(fastmri_test_url) + ) _download_url( - url=fastmri_test_url, data_dir=data_dir, name=knee_test_filename) + url=fastmri_test_url, data_dir=data_dir, name=knee_test_filename + ) return data_dir @@ -432,18 +462,18 @@ def setup_fastmri(data_dir): # Rename folders to match what the workload expects os.rename( - os.path.join(data_dir, "singlecoil_train"), - 
os.path.join(data_dir, "knee_singlecoil_train"), + os.path.join(data_dir, 'singlecoil_train'), + os.path.join(data_dir, 'knee_singlecoil_train'), ) os.rename( - os.path.join(data_dir, "singlecoil_val"), - os.path.join(data_dir, "knee_singlecoil_val"), + os.path.join(data_dir, 'singlecoil_val'), + os.path.join(data_dir, 'knee_singlecoil_val'), ) os.rename( - os.path.join(data_dir, "singlecoil_test"), - os.path.join(data_dir, "knee_singlecoil_test"), + os.path.join(data_dir, 'singlecoil_test'), + os.path.join(data_dir, 'knee_singlecoil_test'), ) - logging.info("Set up fastMRI dataset complete") + logging.info('Set up fastMRI dataset complete') def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url): @@ -456,26 +486,32 @@ def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url): # been moved to the manual_download_dir. # Get paths in manual_download_dir. imagenet_jax_data_dir = os.path.join(data_dir, 'jax') - manual_download_dir = os.path.join(imagenet_jax_data_dir, - 'downloads', - 'manual') - imagenet_train_download_filepath = os.path.join(manual_download_dir, - IMAGENET_TRAIN_TAR_FILENAME) - imagenet_val_download_filepath = os.path.join(manual_download_dir, - IMAGENET_VAL_TAR_FILENAME) + manual_download_dir = os.path.join( + imagenet_jax_data_dir, 'downloads', 'manual' + ) + imagenet_train_download_filepath = os.path.join( + manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME + ) + imagenet_val_download_filepath = os.path.join( + manual_download_dir, IMAGENET_VAL_TAR_FILENAME + ) # Download imagenet train dataset if not os.path.exists(imagenet_train_filepath) and not os.path.exists( - imagenet_train_download_filepath): + imagenet_train_download_filepath + ): logging.info( - 'Downloading imagenet train dataset from {}'.format(imagenet_train_url)) + 'Downloading imagenet train dataset from {}'.format(imagenet_train_url) + ) _download_url(url=imagenet_train_url, data_dir=data_dir) # Download imagenet val dataset if not os.path.exists(imagenet_val_filepath) and not os.path.exists( - imagenet_val_download_filepath): - logging.info('Downloading imagenet validation dataset from {}'.format( - imagenet_val_url)) + imagenet_val_download_filepath + ): + logging.info( + 'Downloading imagenet validation dataset from {}'.format(imagenet_val_url) + ) _download_url(url=imagenet_val_url, data_dir=data_dir) # Download imagenet test set @@ -501,31 +537,40 @@ def setup_imagenet_jax(data_dir): # Setup jax dataset dir imagenet_jax_data_dir = os.path.join(data_dir, 'jax') - manual_download_dir = os.path.join(imagenet_jax_data_dir, - 'downloads', - 'manual') + manual_download_dir = os.path.join( + imagenet_jax_data_dir, 'downloads', 'manual' + ) os.makedirs(manual_download_dir, exist_ok=True) # Copy tar file into jax/downloads/manual logging.info('Checking if tar files already exists in jax/downloads/manual.') if not os.path.exists( - os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)): - logging.info('Moving {} to {}'.format(train_tar_file_path, - manual_download_dir)) + os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME) + ): + logging.info( + 'Moving {} to {}'.format(train_tar_file_path, manual_download_dir) + ) shutil.move(train_tar_file_path, manual_download_dir) if not os.path.exists( - os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)): - logging.info('Moving {} to {}'.format(val_tar_file_path, - manual_download_dir)) + os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME) + ): + logging.info( + 'Moving {} to {}'.format(val_tar_file_path, 
manual_download_dir) + ) shutil.move(val_tar_file_path, manual_download_dir) if not os.path.exists(os.path.join(imagenet_jax_data_dir, 'imagenet_v2')): - logging.info('Moving imagenet_v2 to {}'.format( - os.path.join(imagenet_jax_data_dir, 'imagenet_v2'))) - shutil.move(test_dir_path, - os.path.join(imagenet_jax_data_dir, 'imagenet_v2')) + logging.info( + 'Moving imagenet_v2 to {}'.format( + os.path.join(imagenet_jax_data_dir, 'imagenet_v2') + ) + ) + shutil.move( + test_dir_path, os.path.join(imagenet_jax_data_dir, 'imagenet_v2') + ) logging.info('Preparing imagenet data.') ds_builder = tfds.builder( - 'imagenet2012:5.1.0', data_dir=os.path.join(imagenet_jax_data_dir)) + 'imagenet2012:5.1.0', data_dir=os.path.join(imagenet_jax_data_dir) + ) ds_builder.download_and_prepare() logging.info('Set up imagenet dataset for jax framework complete') @@ -539,14 +584,18 @@ def setup_imagenet_pytorch(data_dir): manual_download_dir = os.path.join(data_dir, 'jax', 'downloads', 'manual') if not os.path.exists(train_tar_file_path): if os.path.exists( - os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)): - train_tar_file_path = os.path.join(manual_download_dir, - IMAGENET_TRAIN_TAR_FILENAME) + os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME) + ): + train_tar_file_path = os.path.join( + manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME + ) if not os.path.exists(val_tar_file_path): if os.path.exists( - os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)): - val_tar_file_path = os.path.join(manual_download_dir, - IMAGENET_VAL_TAR_FILENAME) + os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME) + ): + val_tar_file_path = os.path.join( + manual_download_dir, IMAGENET_VAL_TAR_FILENAME + ) # Setup pytorch dataset dir imagenet_pytorch_data_dir = os.path.join(data_dir, 'pytorch') @@ -557,56 +606,68 @@ def setup_imagenet_pytorch(data_dir): # Move tar files and imagenet_v2 into pytorch directory if not os.path.exists( - os.path.join(imagenet_pytorch_data_dir, IMAGENET_TRAIN_TAR_FILENAME)): - logging.info('Moving {} to {}'.format(train_tar_file_path, - imagenet_pytorch_data_dir)) + os.path.join(imagenet_pytorch_data_dir, IMAGENET_TRAIN_TAR_FILENAME) + ): + logging.info( + 'Moving {} to {}'.format(train_tar_file_path, imagenet_pytorch_data_dir) + ) shutil.move(train_tar_file_path, imagenet_pytorch_data_dir) if not os.path.exists( - os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME)): - logging.info('Moving {} to {}'.format(val_tar_file_path, - imagenet_pytorch_data_dir)) + os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME) + ): + logging.info( + 'Moving {} to {}'.format(val_tar_file_path, imagenet_pytorch_data_dir) + ) shutil.move(val_tar_file_path, imagenet_pytorch_data_dir) if not os.path.exists(os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')): - logging.info('Moving imagenet_v2 to {}'.format( - os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2'))) - shutil.move(test_dir_path, - os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')) + logging.info( + 'Moving imagenet_v2 to {}'.format( + os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2') + ) + ) + shutil.move( + test_dir_path, os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2') + ) # Extract train data\ logging.info('Extracting imagenet train data') extract( - os.path.join(imagenet_pytorch_data_dir, IMAGENET_TRAIN_TAR_FILENAME), - os.path.join(imagenet_pytorch_data_dir, 'train'), - mode='r:') + os.path.join(imagenet_pytorch_data_dir, IMAGENET_TRAIN_TAR_FILENAME), + 
os.path.join(imagenet_pytorch_data_dir, 'train'), + mode='r:', + ) train_tar_filenames = os.listdir( - os.path.join(imagenet_pytorch_data_dir, 'train')) + os.path.join(imagenet_pytorch_data_dir, 'train') + ) for tar_filename in train_tar_filenames: if tar_filename.endswith('.tar'): dir_name = tar_filename[:-4] extract( - os.path.join(imagenet_pytorch_data_dir, 'train', tar_filename), - os.path.join(imagenet_pytorch_data_dir, 'train', dir_name), - mode='r:') + os.path.join(imagenet_pytorch_data_dir, 'train', tar_filename), + os.path.join(imagenet_pytorch_data_dir, 'train', dir_name), + mode='r:', + ) # Extract val data logging.info('Extracting imagenet val data') extract( - os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME), - os.path.join(imagenet_pytorch_data_dir, 'val'), - mode='r:') + os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME), + os.path.join(imagenet_pytorch_data_dir, 'val'), + mode='r:', + ) valprep_command = [ - 'wget', - '-qO-', - 'https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh' + 'wget', + '-qO-', + 'https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh', ] valprep_download = subprocess.Popen(valprep_command, stdout=subprocess.PIPE) - valprep_process = subprocess.Popen(['bash'], - stdin=valprep_download.stdout, - cwd=os.path.expanduser( - os.path.join(imagenet_pytorch_data_dir, - 'val'))) + valprep_process = subprocess.Popen( + ['bash'], + stdin=valprep_download.stdout, + cwd=os.path.expanduser(os.path.join(imagenet_pytorch_data_dir, 'val')), + ) valprep_download.stdout.close() valprep_process.communicate() logging.info('Set up imagenet dataset for pytorch framework complete') @@ -614,8 +675,8 @@ def setup_imagenet_pytorch(data_dir): def download_imagenet_v2(data_dir): tfds.builder( - 'imagenet_v2/matched-frequency:3.0.0', - data_dir=data_dir).download_and_prepare() + 'imagenet_v2/matched-frequency:3.0.0', data_dir=data_dir + ).download_and_prepare() def download_librispeech(data_dir, tmp_dir): @@ -634,41 +695,46 @@ def download_librispeech(data_dir, tmp_dir): if split == 'test' and version == 'other': continue wget_cmd = ( - f'wget --directory-prefix={tmp_librispeech_dir} ' - f'http://www.openslr.org/resources/12/{split}-{version}.tar.gz') + f'wget --directory-prefix={tmp_librispeech_dir} ' + f'http://www.openslr.org/resources/12/{split}-{version}.tar.gz' + ) subprocess.Popen(wget_cmd, shell=True).communicate() tar_path = os.path.join(tmp_librispeech_dir, f'{split}-{version}.tar.gz') subprocess.Popen( - f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', - shell=True).communicate() + f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', shell=True + ).communicate() tars = [ - 'raw-metadata.tar.gz', - 'train-clean-100.tar.gz', - 'train-clean-360.tar.gz', - 'train-other-500.tar.gz', + 'raw-metadata.tar.gz', + 'train-clean-100.tar.gz', + 'train-clean-360.tar.gz', + 'train-other-500.tar.gz', ] for tar_filename in tars: - wget_cmd = (f'wget --directory-prefix={tmp_librispeech_dir} ' - f'http://www.openslr.org/resources/12/{tar_filename}') + wget_cmd = ( + f'wget --directory-prefix={tmp_librispeech_dir} ' + f'http://www.openslr.org/resources/12/{tar_filename}' + ) subprocess.Popen(wget_cmd, shell=True).communicate() tar_path = os.path.join(tmp_librispeech_dir, tar_filename) subprocess.Popen( - f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', - shell=True).communicate() + f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', shell=True + ).communicate() 
tokenizer_vocab_path = os.path.join(final_data_dir, 'spm_model.vocab') if not os.path.exists(tokenizer_vocab_path): librispeech_tokenizer.run( - train=True, - input_dir=extracted_data_dir, - tokenizer_vocab_path=tokenizer_vocab_path) + train=True, + input_dir=extracted_data_dir, + tokenizer_vocab_path=tokenizer_vocab_path, + ) librispeech_preprocess.run( - input_dir=extracted_data_dir, - output_dir=final_data_dir, - tokenizer_vocab_path=tokenizer_vocab_path) + input_dir=extracted_data_dir, + output_dir=final_data_dir, + tokenizer_vocab_path=tokenizer_vocab_path, + ) def download_mnist(data_dir): @@ -691,12 +757,14 @@ def download_wmt(data_dir): if ds_name == 'wmt17_translate/de-en:1.0.0': ds = dataset_builder.as_dataset(split='train', shuffle_files=False) ds = ds.map( - functools.partial(normalize_feature_names, dataset_builder.info), - num_parallel_calls=tf.data.AUTOTUNE) + functools.partial(normalize_feature_names, dataset_builder.info), + num_parallel_calls=tf.data.AUTOTUNE, + ) # Tokenize data. vocab_path = os.path.join(data_dir, 'wmt_sentencepiece_model') tokenizer.train_tokenizer( - ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7) + ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7 + ) def main(_): @@ -715,10 +783,9 @@ def main(_): if FLAGS.all or FLAGS.criteo1tb: logging.info('Downloading criteo1tb...') - download_criteo1tb(data_dir, - tmp_dir, - num_decompression_threads, - FLAGS.interactive_deletion) + download_criteo1tb( + data_dir, tmp_dir, num_decompression_threads, FLAGS.interactive_deletion + ) if FLAGS.all or FLAGS.mnist: logging.info('Downloading MNIST...') @@ -730,19 +797,24 @@ def main(_): knee_singlecoil_train_url = FLAGS.fastmri_knee_singlecoil_train_url knee_singlecoil_val_url = FLAGS.fastmri_knee_singlecoil_val_url knee_singlecoil_test_url = FLAGS.fastmri_knee_singlecoil_test_url - if None in (knee_singlecoil_train_url, - knee_singlecoil_val_url, - knee_singlecoil_test_url): + if None in ( + knee_singlecoil_train_url, + knee_singlecoil_val_url, + knee_singlecoil_test_url, + ): raise ValueError( - 'Must provide three --fastmri_knee_singlecoil_[train,val,test]_url ' - 'to download the FastMRI dataset.\nSign up for the URLs at ' - 'https://fastmri.med.nyu.edu/.') + 'Must provide three --fastmri_knee_singlecoil_[train,val,test]_url ' + 'to download the FastMRI dataset.\nSign up for the URLs at ' + 'https://fastmri.med.nyu.edu/.' + ) if not FLAGS.skip_download: - download_fastmri(data_dir, - knee_singlecoil_train_url, - knee_singlecoil_val_url, - knee_singlecoil_test_url) + download_fastmri( + data_dir, + knee_singlecoil_train_url, + knee_singlecoil_val_url, + knee_singlecoil_test_url, + ) logging.info('fastMRI download completed. Extracting...') setup_fastmri(data_dir) @@ -754,12 +826,13 @@ def main(_): imagenet_val_url = FLAGS.imagenet_val_url if imagenet_train_url is None or imagenet_val_url is None: raise ValueError( - 'Must provide both --imagenet_{train,val}_url to download the ' - 'ImageNet dataset. Sign up for the URLs at https://image-net.org/.') + 'Must provide both --imagenet_{train,val}_url to download the ' + 'ImageNet dataset. Sign up for the URLs at https://image-net.org/.' + ) if FLAGS.framework is None: raise ValueError( - 'Please specify either jax or pytorch framework through framework ' - 'flag.') + 'Please specify either jax or pytorch framework through framework flag.' 
+ ) if not FLAGS.skip_download: logging.info('Downloading ImageNet...') download_imagenet(data_dir, imagenet_train_url, imagenet_val_url) diff --git a/datasets/librispeech_preprocess.py b/datasets/librispeech_preprocess.py index a8c5cae1d..1c216db46 100644 --- a/datasets/librispeech_preprocess.py +++ b/datasets/librispeech_preprocess.py @@ -4,16 +4,15 @@ import multiprocessing.dummy import os -from os.path import exists import sys import threading import time -from absl import logging import numpy as np import pandas as pd -from pydub import AudioSegment import tensorflow as tf +from absl import logging +from pydub import AudioSegment from datasets import librispeech_tokenizer @@ -28,17 +27,18 @@ # taken from TFDS page for librispeech dataset : # https://www.tensorflow.org/datasets/catalog/librispeech librispeech_example_counts = { - 'train-clean-100': 28539, - 'train-clean-360': 104014, - 'train-other-500': 148688, - 'test-clean': 2620, # 'test-other': 2939, - 'dev-clean': 2703, - 'dev-other': 2864, + 'train-clean-100': 28539, + 'train-clean-360': 104014, + 'train-other-500': 148688, + 'test-clean': 2620, # 'test-other': 2939, + 'dev-clean': 2703, + 'dev-other': 2864, } class Counter: """A threadsafe counter.""" + lock = threading.Lock() value = 0 @@ -56,10 +56,12 @@ def report_progress(count, total, start_time): now = time.time() size = 50 filled = int(round(size * count / float(total))) - percent = round(100. * count / float(total), 1) - bar = "-" * filled + "." * (size - filled) - sys.stdout.write("[%s] %d%% (%d of %d) %.2f sample/sec\r" % - (bar, percent, count, total, count / (now - start_time))) + percent = round(100.0 * count / float(total), 1) + bar = '-' * filled + '.' * (size - filled) + sys.stdout.write( + '[%s] %d%% (%d of %d) %.2f sample/sec\r' + % (bar, percent, count, total, count / (now - start_time)) + ) sys.stdout.flush() @@ -72,17 +74,20 @@ def process(index): data_folder, speaker_folder, chapter_folder = index utterance_ids = [] - trans_file = (f'{data_folder}/{speaker_folder}/{chapter_folder}/' - f'{speaker_folder}-{chapter_folder}.trans.txt') + trans_file = ( + f'{data_folder}/{speaker_folder}/{chapter_folder}/' + f'{speaker_folder}-{chapter_folder}.trans.txt' + ) if not exists(trans_file): skipped.inc() return utterance_ids with open(trans_file, 'r', encoding='UTF-8') as f: - for l in f: - utt, trans = l.strip().split(' ', maxsplit=1) + for line in f: + utt, trans = line.strip().split(' ', maxsplit=1) audio_path = ( - f'{data_folder}/{speaker_folder}/{chapter_folder}/{utt}.flac') + f'{data_folder}/{speaker_folder}/{chapter_folder}/{utt}.flac' + ) if not os.path.isfile(audio_path): skipped.inc() @@ -105,9 +110,11 @@ def process(index): np.save('{}/{}/{}_targets.npy'.format(out_folder, split, utt), targets) finished.inc() - report_progress(finished.val() + skipped.val(), - librispeech_example_counts[split], - start_time) + report_progress( + finished.val() + skipped.val(), + librispeech_example_counts[split], + start_time, + ) utterance_ids.append(utt) return utterance_ids @@ -126,10 +133,12 @@ def process(index): end_time = time.time() elapsed_time = end_time - start_time - print(' \n time taken to preprocess split : ', - split, - ' = ', - time.strftime("%H:%M:%S", time.gmtime(elapsed_time))) + print( + ' \n time taken to preprocess split : ', + split, + ' = ', + time.strftime('%H:%M:%S', time.gmtime(elapsed_time)), + ) final_count = finished.val() + skipped.val() return pd.DataFrame(file_trans, columns=['id']), final_count @@ -147,12 +156,12 @@ def run(input_dir, 
output_dir, tokenizer_vocab_path): os.makedirs(output_dir, exist_ok=True) subset_list = [ - 'train-clean-100', - 'train-clean-360', - 'train-other-500', - 'dev-clean', - 'dev-other', - 'test-clean', # 'test-other', + 'train-clean-100', + 'train-clean-360', + 'train-other-500', + 'dev-clean', + 'dev-other', + 'test-clean', # 'test-other', ] for subset in subset_list: logging.info('Processing split = %s...', subset) @@ -160,10 +169,14 @@ def run(input_dir, output_dir, tokenizer_vocab_path): out_dir = os.path.join(output_dir, subset) os.makedirs(out_dir, exist_ok=True) example_ids, num_entries = preprocess_data( - in_dir, output_dir, tokenizer, subset) + in_dir, output_dir, tokenizer, subset + ) if num_entries != librispeech_example_counts[subset]: - raise ValueError('Preprocessed dataframe final count not equal to ' - 'expected count: {} vs expected {}'.format( - num_entries, librispeech_example_counts[subset])) + raise ValueError( + 'Preprocessed dataframe final count not equal to ' + 'expected count: {} vs expected {}'.format( + num_entries, librispeech_example_counts[subset] + ) + ) example_ids.to_csv(os.path.join(output_dir, f'{subset}.csv')) diff --git a/datasets/librispeech_tokenizer.py b/datasets/librispeech_tokenizer.py index 2f559752a..d566d5716 100644 --- a/datasets/librispeech_tokenizer.py +++ b/datasets/librispeech_tokenizer.py @@ -8,10 +8,10 @@ import tempfile from typing import Dict -from absl import logging import sentencepiece as spm import tensorflow as tf import tensorflow_text as tftxt +from absl import logging gfile = tf.io.gfile copy = tf.io.gfile.copy @@ -24,7 +24,8 @@ def dump_chars_for_training(data_folder, splits, maxchars: int = int(1e7)): char_count = 0 with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/ds_chars') as outfp: + delete=False, prefix='/tmp/ds_chars' + ) as outfp: for split in splits: data_folder = data_folder + '/' + split for _, speaker_folder in enumerate(os.listdir(data_folder)): @@ -32,14 +33,16 @@ def dump_chars_for_training(data_folder, splits, maxchars: int = int(1e7)): break for chapter_folder in os.listdir(f'{data_folder}/{speaker_folder}'): - trans_file = (f'{data_folder}/{speaker_folder}/{chapter_folder}/' - f'{speaker_folder}-{chapter_folder}.trans.txt') + trans_file = ( + f'{data_folder}/{speaker_folder}/{chapter_folder}/' + f'{speaker_folder}-{chapter_folder}.trans.txt' + ) if not exists(trans_file): logging.info('path does not exist -> %s', trans_file) continue with open(trans_file, 'r', encoding='UTF-8') as f: - for l in f: - _, line = l.strip().split(' ', maxsplit=1) + for lines in f: + _, line = lines.strip().split(' ', maxsplit=1) line = line + '\n' char_count += len(line) if char_count > maxchars: @@ -50,13 +53,15 @@ def dump_chars_for_training(data_folder, splits, maxchars: int = int(1e7)): return outfp -def train_tokenizer(data_dir: str, - splits, - vocab_size: int = 1024, - model_path: str = 'spm_model.vocab', - maxchars: int = int(1e7), - model_type: str = 'unigram', - character_coverage: float = 1.0): +def train_tokenizer( + data_dir: str, + splits, + vocab_size: int = 1024, + model_path: str = 'spm_model.vocab', + maxchars: int = int(1e7), + model_type: str = 'unigram', + character_coverage: float = 1.0, +): """Train SentencePiece tokenizer from subset of tf dataset. 
Args: @@ -77,15 +82,18 @@ def train_tokenizer(data_dir: str, charfile = dump_chars_for_training(data_dir, splits, maxchars=maxchars) with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/sp_tmp') as model_fp: + delete=False, prefix='/tmp/sp_tmp' + ) as model_fp: pass # we just want a prefix'd tmp-filename - argstr = ' '.join([ + argstr = ' '.join( + [ f'--input={charfile.name}', f'--vocab_size={vocab_size}', f'--character_coverage={character_coverage}', f'--model_prefix={model_fp.name}', f'--model_type={model_type}', - ]) + ] + ) spm.SentencePieceTrainer.Train(argstr) copy_rename_path = abs_model_path + '.rntmp' @@ -104,7 +112,8 @@ def load_tokenizer(model_filepath): with gfile.GFile(model_filepath, 'rb') as model_fp: sp_model = model_fp.read() sp_tokenizer = tftxt.SentencepieceTokenizer( - model=sp_model, add_bos=False, add_eos=True, reverse=False) + model=sp_model, add_bos=False, add_eos=True, reverse=False + ) return sp_tokenizer @@ -123,8 +132,9 @@ def run(train, input_dir, tokenizer_vocab_path): detokenized = tokenizer.detokenize(tokens).numpy().decode('utf-8') logging.info('Original input = %s', test_input) - logging.info('Output after after tokenizing and detokenizing = %s', - detokenized) + logging.info( + 'Output after after tokenizing and detokenizing = %s', detokenized + ) if detokenized == test_input: logging.info('Tokenizer working correctly!') diff --git a/docker/Dockerfile b/docker/Dockerfile index 76bc5cfe0..9926b0542 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -23,8 +23,8 @@ RUN apt-get update && apt-get install -y \ libreadline-dev \ libffi-dev \ curl \ - libbz2-dev \ liblzma-dev \ + libbz2-dev \ vim # Download and install Python 3.11 @@ -56,8 +56,6 @@ RUN echo "Setting up directories for data and experiment_runs" RUN mkdir -p data/ RUN mkdir -p experiment_runs/ -RUN pip install --upgrade pip - # Install Algorithmic efficiency repo RUN pip install --upgrade pip @@ -71,18 +69,18 @@ RUN cd /algorithmic-efficiency && git checkout $branch RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ - && pip install -e '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ + && pip install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu \ + && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_cpu]' \ - && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ + && pip install -e '.[pytorch_gpu]' \ + && pip install -e '.[jax_cpu]'; \ elif [ "$framework" = "both" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ - && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ + && pip install -e '.[pytorch_gpu]' \ + && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \ else \ echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." 
>&2 \
&& exit 1 ; \
diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh
index 645b81955..6b5e67ceb 100644
--- a/docker/build_docker_images.sh
+++ b/docker/build_docker_images.sh
@@ -1,27 +1,40 @@
 #!/bin/bash
 # Bash script to build and push dev docker images to artifact repo
 # Usage:
-# bash build_docker_images.sh -b 
+# bash build_docker_images.sh -b -f 
 # Make program exit with non-zero exit code if any command fails.
 set -e

-while getopts b: flag
+while getopts "b:p:f:" flag;
 do
   case "${flag}" in
     b) GIT_BRANCH=${OPTARG};;
+    p) PROJECT=${OPTARG};;
+    f) FRAMEWORK=${OPTARG};;
   esac
 done

 # Artifact repostiory
-ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo"
+if [ "$PROJECT" = "mlcommons-algoperf" ]; then
+  ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo"
+else
+  ARTIFACT_REPO="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo"
+fi

-if [[ -z ${GIT_BRANCH+x} ]]
+if [[ -z ${GIT_BRANCH+x} ]];
 then
   GIT_BRANCH='main' # Set default argument
 fi

-for FRAMEWORK in "jax" "pytorch" "both"
+FRAMEWORKS=( "jax" "pytorch" "both" )
+
+if [[ -n "$FRAMEWORK" ]];
+then
+  FRAMEWORKS=("$FRAMEWORK")
+fi
+
+for FRAMEWORK in "${FRAMEWORKS[@]}";
 do
   IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}"
   DOCKER_BUILD_COMMAND="docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH"
diff --git a/docker/scripts/singularity_converter.py b/docker/scripts/singularity_converter.py
index 48c521009..a816eb5c2 100644
--- a/docker/scripts/singularity_converter.py
+++ b/docker/scripts/singularity_converter.py
@@ -15,26 +15,27 @@ from spython.main.parse.writers import get_writer

 # globals
-ENTRY_POINT = "/bin/bash"  # seems to be a good default
+ENTRY_POINT = '/bin/bash'  # seems to be a good default
 FORCE = False  # seems to be a good default

 #
-parser = argparse.ArgumentParser(description="Custom Singularity converter")
+parser = argparse.ArgumentParser(description='Custom Singularity converter')
 parser.add_argument(
-    "-i", "--input", type=str, help="Docker input path", default="Dockerfile")
+  '-i', '--input', type=str, help='Docker input path', default='Dockerfile'
+)
 parser.add_argument(
-    "-o",
-    "--output",
-    type=str,
-    help="Singularity output path",
-    default="Singularity.def",
+  '-o',
+  '--output',
+  type=str,
+  help='Singularity output path',
+  default='Singularity.def',
 )
 args = parser.parse_args()
 INPUT_DOCKERFILE_PATH = args.input
 OUTPUT_SINGULARITY_PATH = args.output

 # create Docker parser and Singularity writer
-parser = get_parser("docker")
-writer = get_writer("singularity")
+parser = get_parser('docker')
+writer = get_writer('singularity')

 # parse Dockerfile into Singularity and suppress %files commands
 recipeParser = parser(INPUT_DOCKERFILE_PATH)
@@ -44,5 +45,5 @@
 # convert to string and save to output file
 result = recipeWriter.convert(runscript=ENTRY_POINT, force=FORCE)
-with open(OUTPUT_SINGULARITY_PATH, "w") as f:
+with open(OUTPUT_SINGULARITY_PATH, 'w') as f:
   f.write(result)
diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md
index 7778030dc..e9918e14c 100644
--- a/docs/CONTRIBUTING.md
+++ b/docs/CONTRIBUTING.md
@@ -179,13 +179,13 @@ docker run -t -d \
 To find the container IDs of running containers

 ```bash
-docker ps 
+docker ps
 ```

 To see the logging output

 ```bash
-docker logs 
+docker logs
 ```

 To enter a bash session in the container
@@ -209,7 +209,7 @@ docker run -t -d \
 --gpus all \
 --ipc=host \
 \
---keep_container_alive true 
+--keep_container_alive true
 ```

 ## Submitting PRs
@@ -222,38 +222,26 @@ We run tests with GitHub Actions, configured in the [.github/workflows](.github/workflows) directory.

 ### Style Testing

-We run yapf and linting tests on PRs. You can view and fix offending errors with these instructions.
-
+We run formatting and linting tests via ruff on PRs. You can view and fix offending errors with these instructions.
 To run the below commands, use the versions installed via `pip install -e '.[dev]'`.

-To automatically fix formatting errors, run the following (*WARNING:* this will edit your code, so it is suggested to make a git commit first!):
+To check whether your code is **formatted** correctly, run the following:

 ```bash
-yapf -i -r -vv -p algoperf datasets prize_qualification_baselines reference_algorithms tests *.py
+ruff format --check
 ```

-To sort all import orderings, run the following:
+To automatically fix formatting errors you can run `ruff format`, without the `--check` flag.
+(**WARNING**: this will edit your code, so it is suggested to make a git commit first!)

-```bash
-isort .
-```
-
-To just print out all offending import orderings, run the following:
+To check whether your code is **linted** correctly, run the following:

 ```bash
-isort . --check --diff
+ruff check
 ```

-To print out all offending pylint issues, run the following:
-
-```bash
-pylint algoperf
-pylint datasets
-pylint prize_qualification_baselines
-pylint reference_algorithms
-pylint submission_runner.py
-pylint tests
-```
+To automatically fix linting errors you can run `ruff check` with the additional `--fix` flag.
+(**WARNING**: this will edit your code, so it is suggested to make a git commit first!)

 ### Unit and Integration Tests
@@ -270,9 +258,9 @@ To run a regression test:

 1. Build and upload latest Docker images from dev branch.

-    ```bash
-    bash ~/algorithmic-efficiency/docker/build_docker_images.sh -b dev
-    ```
+   ```bash
+   bash ~/algorithmic-efficiency/docker/build_docker_images.sh -b dev
+   ```

 2. Turn on the self-hosted runner.
 3. Run the self-hosted runner application for the runner to accept jobs.
diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
index c451a18ac..62161b3d5 100644
--- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
+++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
@@ -1,26 +1,24 @@
 """Submission file for an NAdamW optimizer with warmup+cosine LR in Jax."""

 import functools
-
-# isort: off
-# We have to turn off isort here to resolve a conflict between isort and yapf. 
-from typing import (Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union) -# isort: on +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) import chex -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec @@ -30,15 +28,14 @@ # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, ) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. @@ -73,19 +70,22 @@ def nadamw( An (init_fn, update_fn) tuple. """ return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate), + ) # All functions below are forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: +def scale_by_nadam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5, +) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. References: @@ -124,7 +124,8 @@ def update_fn(updates, state, params=None): mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat + ) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -132,6 +133,7 @@ def update_fn(updates, state, params=None): class ScaleByAdamState(NamedTuple): """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. 
mu: optax.Updates nu: optax.Updates @@ -140,7 +142,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments + ) def _bias_correction(moment, decay, count): @@ -156,11 +159,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True): return optax.scale(m * learning_rate) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_params del model_state @@ -170,101 +175,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create optimizer + LR schedule. 
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. 
(summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -281,37 +300,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. 
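# Sketch of the global-norm clipping applied above: the whole gradient pytree
# is rescaled by min(1, grad_clip / (||g|| + eps)). Toy values, illustrative only.
import jax
import jax.numpy as jnp

def clip_by_global_norm(grad, max_norm, eps=1e-6):
  grad_norm = jnp.sqrt(
    sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
  )
  scale = jax.lax.clamp(min=0.0, x=max_norm / (grad_norm + eps), max=1.0)
  return jax.tree.map(lambda g: g * scale, grad), grad_norm

toy_grad = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([0.0])}  # global norm = 5
clipped, norm = clip_by_global_norm(toy_grad, max_norm=1.0)
# norm == 5.0; every leaf is scaled by ~0.2, so the clipped global norm is ~1.0.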
if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -351,14 +376,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index b8ac10f33..9752aef33 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -1,26 +1,24 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" import functools - -# isort: off -# We have to turn off isort here to resolve a conflict between isort and yapf. 
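# The bodies of data_selection and get_batch_size sit outside this diff's
# context lines. Purely as a hypothetical sketch of what the data_selection
# docstring describes (an assumption, not the repository's code), a
# pass-through selection over the pre-shuffled queue could look like:
def data_selection_sketch(input_queue):
  # Hypothetical: take the next pre-shuffled batch as-is.
  return next(input_queue)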
-from typing import (Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union) -# isort: on +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) import chex -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec @@ -30,15 +28,14 @@ # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, ) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. @@ -73,19 +70,22 @@ def nadamw( An (init_fn, update_fn) tuple. """ return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate), + ) # All functions below are forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: +def scale_by_nadam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5, +) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. References: @@ -124,7 +124,8 @@ def update_fn(updates, state, params=None): mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat + ) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -132,6 +133,7 @@ def update_fn(updates, state, params=None): class ScaleByAdamState(NamedTuple): """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. 
mu: optax.Updates nu: optax.Updates @@ -140,7 +142,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments + ) def _bias_correction(moment, decay, count): @@ -156,11 +159,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True): return optax.scale(m * learning_rate) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_params del model_state @@ -170,101 +175,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create optimizer + LR schedule. 
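# To make the Nesterov flavour concrete: compared to Adam, NAdam applies the
# first-moment decay once more with the current gradient before bias
# correction. A scalar sketch based on the helpers above and the referenced
# init2winit transform (the exact elided lines may differ slightly):
import math

def nadam_scalar_step(g, mu, nu, count, b1=0.9, b2=0.999, eps=1e-8):
  mu = b1 * mu + (1 - b1) * g        # EMA of gradients (first moment)
  nu = b2 * nu + (1 - b2) * g**2     # EMA of squared gradients (second moment)
  count += 1
  mu_hat = b1 * mu + (1 - b1) * g    # Nesterov look-ahead: decay applied again with g
  mu_hat = mu_hat / (1 - b1**count)  # bias correction, as in _bias_correction
  nu_hat = nu / (1 - b2**count)
  update = mu_hat / (math.sqrt(nu_hat) + eps)
  return update, mu, nu, count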
lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. 
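# The visible difference from the full-budget file above is the shortened
# schedule: warmup and cosine decay are laid out over 75% of the workload's
# step hint. With a hypothetical step hint (real hints are workload-specific):
step_hint_example = 100_000
warmup_factor_example = 0.1

scaled_hint = step_hint_example * 0.75                             # 75_000.0
warmup_steps_example = int(warmup_factor_example * scaled_hint)    # 7_500
cosine_steps_example = max(scaled_hint - warmup_steps_example, 1)  # 67_500.0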
(summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -281,37 +300,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. 
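# Note on the logging that follows: pmapped_train_step returns
# device-replicated values, so loss[0] and grad_norm[0] simply read the copy on
# the first device (identical across devices after the psum above). A minimal sketch:
import jax
import jax.numpy as jnp

replicated = jax.pmap(lambda x: x * 0 + 1.23)(
  jnp.zeros(jax.local_device_count())
)
print(replicated.shape)  # (num_devices,)
print(replicated[0])     # the single copy that would be logged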
if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -351,14 +376,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index a2f9fb4c5..f6c2faa9d 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -3,13 +3,11 @@ import math from typing import Any, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch -from torch import Tensor import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from absl import logging +from torch import Tensor +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec from algoperf.pytorch_utils import pytorch_setup @@ -21,33 +19,30 @@ class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. 
- - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -59,7 +54,10 @@ def __init__(self, if not 0.0 <= weight_decay: raise ValueError(f'Invalid weight_decay value: {weight_decay}') defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, } super().__init__(params, defaults) @@ -67,7 +65,8 @@ def __setstate__(self, state): super().__setstate__(state) state_values = list(self.state.values()) step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) + state_values[0]['step'] + ) if not step_is_tensor: for s in state_values: s['step'] = torch.tensor(float(s['step'])) @@ -76,9 +75,9 @@ def __setstate__(self, state): def step(self, closure=None): """Performs a single optimization step. - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. """ self._cuda_graph_capture_health_check() @@ -107,51 +106,57 @@ def step(self, closure=None): # State initialization if len(state) == 0: - state['step'] = torch.tensor(0.) 
+ state['step'] = torch.tensor(0.0) # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) state_steps.append(state['step']) nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) return loss -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float) -> None: +def nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +) -> None: r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. + See NAdamW class for details. """ if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) for i, param in enumerate(params): grad = grads[i] @@ -189,54 +194,59 @@ def nadamw(params: List[Tensor], exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_state del rng optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), + 'optimizer': NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ), } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) optimizer_state['scheduler'] = pytorch_cosine_warmup( - 
workload.step_hint, hyperparameters, optimizer_state['optimizer']) + workload.step_hint, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -248,26 +258,30 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -280,7 +294,8 @@ def update_params( if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -289,31 +304,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def 
prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -353,14 +375,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index a37b0d341..68ff30b2a 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -3,13 +3,11 @@ import math from typing import Any, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch -from torch import Tensor import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from absl import logging +from torch import Tensor +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec from algoperf.pytorch_utils import pytorch_setup @@ -21,33 +19,30 @@ class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. 
- - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -59,7 +54,10 @@ def __init__(self, if not 0.0 <= weight_decay: raise ValueError(f'Invalid weight_decay value: {weight_decay}') defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, } super().__init__(params, defaults) @@ -67,7 +65,8 @@ def __setstate__(self, state): super().__setstate__(state) state_values = list(self.state.values()) step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) + state_values[0]['step'] + ) if not step_is_tensor: for s in state_values: s['step'] = torch.tensor(float(s['step'])) @@ -76,9 +75,9 @@ def __setstate__(self, state): def step(self, closure=None): """Performs a single optimization step. - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. """ self._cuda_graph_capture_health_check() @@ -107,51 +106,57 @@ def step(self, closure=None): # State initialization if len(state) == 0: - state['step'] = torch.tensor(0.) 
+ state['step'] = torch.tensor(0.0) # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) state_steps.append(state['step']) nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) return loss -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float) -> None: +def nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +) -> None: r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. + See NAdamW class for details. """ if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) for i, param in enumerate(params): grad = grads[i] @@ -189,54 +194,59 @@ def nadamw(params: List[Tensor], exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_state del rng optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), + 'optimizer': NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ), } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) optimizer_state['scheduler'] = pytorch_cosine_warmup( - 
workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) + workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -248,26 +258,30 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -280,7 +294,8 @@ def update_params( if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -289,31 +304,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def 
prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -353,14 +375,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. 
""" diff --git a/prize_qualification_baselines/external_tuning/tuning_search_space.json b/prize_qualification_baselines/external_tuning/tuning_search_space.json index 199f77041..ad68372c6 100644 --- a/prize_qualification_baselines/external_tuning/tuning_search_space.json +++ b/prize_qualification_baselines/external_tuning/tuning_search_space.json @@ -1,53 +1,47 @@ [ - { - "dropout_rate": 0.0, - "label_smoothing": 0.1, - "learning_rate": 0.001308209823469072, - "one_minus_beta1": 0.02686663061, - "beta2": 0.9981232922116359, - "weight_decay": 0.16375311233774334, - "warmup_factor": 0.1 - }, - { - "dropout_rate": 0.0, - "label_smoothing": 0.2, - "learning_rate": 0.0008445074561975979, - "one_minus_beta1": 0.11042418465, - "beta2": 0.9978504782314613, - "weight_decay": 0.08135402759553023, - "warmup_factor": 0.05 - }, - { - "dropout_rate": 0.0, - "label_smoothing": 0.0, - "learning_rate": 0.001308209823469072, - "one_minus_beta1": 0.02686663061, - "beta2": 0.9981232922116359, - "weight_decay": 0.16375311233774334, - "warmup_factor": 0.1 - }, - { - "dropout_rate": 0.0, - "label_smoothing": 0.0, - "learning_rate": 0.004958460849689891, - "one_minus_beta1": 0.13625575743, - "beta2": 0.6291854735396584, - "weight_decay": 0.1147386261512052, - "warmup_factor": 0.02 - }, - { - "dropout_rate": 0.1, - "label_smoothing": 0.0, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 - } + { + "dropout_rate": 0.0, + "label_smoothing": 0.1, + "learning_rate": 0.001308209823469072, + "one_minus_beta1": 0.02686663061, + "beta2": 0.9981232922116359, + "weight_decay": 0.16375311233774334, + "warmup_factor": 0.1 + }, + { + "dropout_rate": 0.0, + "label_smoothing": 0.2, + "learning_rate": 0.0008445074561975979, + "one_minus_beta1": 0.11042418465, + "beta2": 0.9978504782314613, + "weight_decay": 0.08135402759553023, + "warmup_factor": 0.05 + }, + { + "dropout_rate": 0.0, + "label_smoothing": 0.0, + "learning_rate": 0.001308209823469072, + "one_minus_beta1": 0.02686663061, + "beta2": 0.9981232922116359, + "weight_decay": 0.16375311233774334, + "warmup_factor": 0.1 + }, + { + "dropout_rate": 0.0, + "label_smoothing": 0.0, + "learning_rate": 0.004958460849689891, + "one_minus_beta1": 0.13625575743, + "beta2": 0.6291854735396584, + "weight_decay": 0.1147386261512052, + "warmup_factor": 0.02 + }, + { + "dropout_rate": 0.1, + "label_smoothing": 0.0, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 + } ] - - - - - - diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 78c3b5b3e..f61a7bdcd 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -1,53 +1,50 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" import functools - -# isort: off -# We have to turn off isort here to resolve a conflict between isort and yapf. 
-from typing import (Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union) -# isort: on +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) import chex -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec _GRAD_CLIP_EPS = 1e-6 HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 + 'dropout_rate': 0.1, + 'learning_rate': 0.0017486387539278373, + 'one_minus_beta1': 0.06733926164, + 'beta2': 0.9955159689799007, + 'weight_decay': 0.08121616522670176, + 'warmup_factor': 0.02, } # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, ) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. @@ -82,19 +79,22 @@ def nadamw( An (init_fn, update_fn) tuple. """ return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate), + ) # All functions below are forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: +def scale_by_nadam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5, +) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. References: @@ -133,7 +133,8 @@ def update_fn(updates, state, params=None): mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat + ) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -141,6 +142,7 @@ def update_fn(updates, state, params=None): class ScaleByAdamState(NamedTuple): """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. 
mu: optax.Updates nu: optax.Updates @@ -149,7 +151,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments + ) def _bias_correction(moment, decay, count): @@ -165,11 +168,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True): return optax.scale(m * learning_rate) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_params del model_state @@ -182,101 +187,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters['warmup_factor'] * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters['learning_rate'], - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters['learning_rate'], + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters['learning_rate'], decay_steps=cosine_steps) + init_value=hyperparameters['learning_rate'], decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create optimizer + LR schedule. 
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters['one_minus_beta1'], - b2=hyperparameters['beta2'], - eps=1e-8, - weight_decay=hyperparameters['weight_decay']) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters['one_minus_beta1'], + b2=hyperparameters['beta2'], + eps=1e-8, + weight_decay=hyperparameters['weight_decay'], + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. 
(summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -296,37 +315,43 @@ def update_params( grad_clip = hyperparameters['grad_clip'] else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. 
if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -366,14 +391,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index ffe854a0e..130ebdabe 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -1,53 +1,50 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" import functools - -# isort: off -# We have to turn off isort here to resolve a conflict between isort and yapf. 
-from typing import (Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union) -# isort: on +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) import chex -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec _GRAD_CLIP_EPS = 1e-6 HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 + 'dropout_rate': 0.1, + 'learning_rate': 0.0017486387539278373, + 'one_minus_beta1': 0.06733926164, + 'beta2': 0.9955159689799007, + 'weight_decay': 0.08121616522670176, + 'warmup_factor': 0.02, } # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, ) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. @@ -82,19 +79,22 @@ def nadamw( An (init_fn, update_fn) tuple. """ return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate), + ) # All functions below are forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: +def scale_by_nadam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5, +) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. References: @@ -133,7 +133,8 @@ def update_fn(updates, state, params=None): mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat + ) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -141,6 +142,7 @@ def update_fn(updates, state, params=None): class ScaleByAdamState(NamedTuple): """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. 
mu: optax.Updates nu: optax.Updates @@ -149,7 +151,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments + ) def _bias_correction(moment, decay, count): @@ -165,11 +168,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True): return optax.scale(m * learning_rate) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_params del model_state @@ -182,101 +187,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters['warmup_factor'] * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters['learning_rate'], - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters['learning_rate'], + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters['learning_rate'], decay_steps=cosine_steps) + init_value=hyperparameters['learning_rate'], decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create optimizer + LR schedule. 
lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters['one_minus_beta1'], - b2=hyperparameters['beta2'], - eps=1e-8, - weight_decay=hyperparameters['weight_decay']) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters['one_minus_beta1'], + b2=hyperparameters['beta2'], + eps=1e-8, + weight_decay=hyperparameters['weight_decay'], + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. 
(summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -296,37 +315,43 @@ def update_params( grad_clip = hyperparameters['grad_clip'] else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. 
if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -366,14 +391,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. 
""" diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index 554a28762..8e8adbeaa 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -4,13 +4,11 @@ import math from typing import Any, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch -from torch import Tensor import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from absl import logging +from torch import Tensor +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec from algoperf.pytorch_utils import pytorch_setup @@ -18,12 +16,12 @@ USE_PYTORCH_DDP = pytorch_setup()[0] HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 + 'dropout_rate': 0.1, + 'learning_rate': 0.0017486387539278373, + 'one_minus_beta1': 0.06733926164, + 'beta2': 0.9955159689799007, + 'weight_decay': 0.08121616522670176, + 'warmup_factor': 0.02, } HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS) @@ -32,33 +30,30 @@ class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. - - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -70,7 +65,10 @@ def __init__(self, if not 0.0 <= weight_decay: raise ValueError(f'Invalid weight_decay value: {weight_decay}') defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, } super().__init__(params, defaults) @@ -78,7 +76,8 @@ def __setstate__(self, state): super().__setstate__(state) state_values = list(self.state.values()) step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) + state_values[0]['step'] + ) if not step_is_tensor: for s in state_values: s['step'] = torch.tensor(float(s['step'])) @@ -87,9 +86,9 @@ def __setstate__(self, state): def step(self, closure=None): """Performs a single optimization step. - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. """ self._cuda_graph_capture_health_check() @@ -118,51 +117,57 @@ def step(self, closure=None): # State initialization if len(state) == 0: - state['step'] = torch.tensor(0.) + state['step'] = torch.tensor(0.0) # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) state_steps.append(state['step']) nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) return loss -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float) -> None: +def nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +) -> None: r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. + See NAdamW class for details. 
""" if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) for i, param in enumerate(params): grad = grads[i] @@ -200,11 +205,13 @@ def nadamw(params: List[Tensor], exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_state del rng @@ -213,44 +220,47 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters = HPARAMS optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), + 'optimizer': NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ), } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint, hyperparameters, optimizer_state['optimizer']) + workload.step_hint, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -265,26 +275,30 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - 
params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -297,7 +311,8 @@ def update_params( if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -306,31 +321,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -370,14 +392,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def 
data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index e4317fa18..5d7c444a4 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -4,13 +4,11 @@ import math from typing import Any, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch -from torch import Tensor import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from absl import logging +from torch import Tensor +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec from algoperf.pytorch_utils import pytorch_setup @@ -18,12 +16,12 @@ USE_PYTORCH_DDP = pytorch_setup()[0] HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 + 'dropout_rate': 0.1, + 'learning_rate': 0.0017486387539278373, + 'one_minus_beta1': 0.06733926164, + 'beta2': 0.9955159689799007, + 'weight_decay': 0.08121616522670176, + 'warmup_factor': 0.02, } HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS) @@ -32,33 +30,30 @@ class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. - - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. 
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -70,7 +65,10 @@ def __init__(self, if not 0.0 <= weight_decay: raise ValueError(f'Invalid weight_decay value: {weight_decay}') defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, } super().__init__(params, defaults) @@ -78,7 +76,8 @@ def __setstate__(self, state): super().__setstate__(state) state_values = list(self.state.values()) step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) + state_values[0]['step'] + ) if not step_is_tensor: for s in state_values: s['step'] = torch.tensor(float(s['step'])) @@ -87,9 +86,9 @@ def __setstate__(self, state): def step(self, closure=None): """Performs a single optimization step. - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. """ self._cuda_graph_capture_health_check() @@ -118,51 +117,57 @@ def step(self, closure=None): # State initialization if len(state) == 0: - state['step'] = torch.tensor(0.) + state['step'] = torch.tensor(0.0) # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) state_steps.append(state['step']) nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) return loss -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float) -> None: +def nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +) -> None: r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. + See NAdamW class for details. 
""" if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) for i, param in enumerate(params): grad = grads[i] @@ -200,11 +205,13 @@ def nadamw(params: List[Tensor], exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_state del rng @@ -213,44 +220,47 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters = HPARAMS optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), + 'optimizer': NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ), } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) + workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -265,26 +275,30 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = 
workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -297,7 +311,8 @@ def update_params( if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -306,31 +321,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -370,14 +392,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, 
spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/pyproject.toml b/pyproject.toml index 4e15e4400..b6cfa42cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,9 +27,9 @@ classifiers = [ "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ @@ -46,7 +46,6 @@ dependencies = [ "clu==0.0.12", "matplotlib>=3.9.2", "tabulate==0.9.0", - ] [build-system] @@ -69,19 +68,11 @@ version_file = "algoperf/_version.py" ############################################################################### [project.optional-dependencies] # All workloads -full = [ - "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]", -] +full = ["algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]"] # All workloads plus development dependencies full_dev = ["algoperf[full,dev]"] # Dependencies for developing the package -dev = [ - "isort==5.12.0", - "pylint==2.17.4", - "pytest==8.3.3", - "yapf==0.32.0", - "pre-commit==4.0.1", -] +dev = ["ruff==0.12.0", "pytest==8.3.3", "pre-commit==4.0.1"] wandb = ["wandb==0.19.6"] @@ -104,11 +95,7 @@ jax_core_deps = [ "ml_dtypes==0.4.1", "protobuf==4.25.5", ] -jax_cpu = [ - "jax==0.4.28", - "jaxlib==0.4.28", - "algoperf[jax_core_deps]", -] +jax_cpu = ["jax==0.4.28", "jaxlib==0.4.28", "algoperf[jax_core_deps]"] jax_gpu = [ "jax==0.4.28", "jaxlib==0.4.28", @@ -117,217 +104,75 @@ jax_gpu = [ "algoperf[jax_core_deps]", ] pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"] -pytorch_gpu = [ - "torch==2.5.1", - "torchvision==0.20.1", -] # Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA. +pytorch_gpu = ["torch==2.5.1", "torchvision==0.20.1"] +# Note: omit the cuda suffix and installing from the appropriate wheel +# will result in using locally installed CUDA. 
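For reference, the mechanical reformatting visible throughout this diff is what the Ruff settings introduced further down in pyproject.toml produce: 80-character lines, 2-space indentation, single-quoted strings, and trailing commas on multi-line calls. The snippet below is only an illustration of that style; the helper and its names are hypothetical and are not part of the repository.

def clip_by_global_norm(
  gradients: dict[str, float],
  max_norm: float,
  eps: float = 1e-6,
) -> dict[str, float]:
  """Hypothetical helper, written in the style the settings below enforce."""
  # Two-space indents, single-quoted strings, and a trailing comma in the
  # signature mirror the conventions applied to the submission files above.
  if max_norm <= 0:
    raise ValueError(f'Invalid max_norm value: {max_norm}')
  total_norm = sum(g**2 for g in gradients.values()) ** 0.5
  scale = min(1.0, max_norm / (total_norm + eps))
  return {name: g * scale for name, g in gradients.items()}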
############################################################################### -# Linting Configurations # +# Linting & Formatting Configurations # ############################################################################### - -# yapf configuration -[tool.yapf] -based_on_style = "yapf" -each_dict_entry_on_separate_line = false -split_all_top_level_comma_separated_values = true -[tool.yapfignore] -ignore_patterns = ["algoperf/_version.py"] - -# isort configuration -[tool.isort] -profile = "google" - -# pylint configuration -[tool.pylint.MASTER] -persistent = false -ignore = "get_references_web.py,get_references_web_single_group.py,_version.py" - -[tool.pylint.REPORTS] -reports = false -msg-template = "{msg_id}:{line:3} {obj}: {msg} [{symbol}]" - -[tool.pylint.MESSAGES_CONTROL] -enable = "indexing-exception,old-raise-syntax" - -[tool.pylint.BASIC] -# Required attributes for module, separated by a comma -#required-attributes= -# Regular expression which should only match the name -# of functions or classes which do not require a docstring. -no-docstring-rgx = "(__.*__|main)" -# Min length in lines of a function that requires a docstring. -docstring-min-length = 10 -# Regular expression which should only match correct module names. The -# leading underscore is sanctioned for private modules by Google's style -# guide. -# -# There are exceptions to the basic rule (_?[a-z][a-z0-9_]*) to cover -# requirements of Python's module system. -module-rgx = "^(_?[a-z][a-z0-9_]*)|__init__$" -# Regular expression which should only match correct module level names -const-rgx = "^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$" -# Regular expression which should only match correct class attribute -class-attribute-rgx = "^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$" -# Regular expression which should only match correct class names -class-rgx = "^_?[A-Z][a-zA-Z0-9]*$" -# Regular expression which should only match correct function names. -# 'camel_case' and 'snake_case' group names are used for consistency of naming -# styles across functions and methods. -function-rgx = "^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$" -# Regular expression which should only match correct method names. -# 'camel_case' and 'snake_case' group names are used for consistency of naming -# styles across functions and methods. 'exempt' indicates a name which is -# consistent with all naming styles. 
-method-rgx = "(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|_testDatasetSize|setUpClass|test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|(?:test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$" -# Regular expression which should only match correct instance attribute names -attr-rgx = "^_{0,2}[a-z][a-z0-9_]*$" -# Regular expression which should only match correct argument names -argument-rgx = "^[a-z][a-z0-9_]*$" -# Regular expression which should only match correct variable names -variable-rgx = "^[a-z][a-z0-9_]*$" -# Regular expression which should only match correct list comprehension / -# generator expression variable names -inlinevar-rgx = "^[a-z][a-z0-9_]*$" -# Good variable names which should always be accepted, separated by a comma -good-names = "main,_" -# Bad variable names which should always be refused, separated by a comma -bad-names = "" -# List of builtins function names that should not be used, separated by a comma -#bad-functions=input,apply,reduce -# List of decorators that define properties, such as abc.abstractproperty. -property-classes = "abc.abstractproperty" - -[tool.pylint.typecheck] -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members = true - -# List of decorators that create context managers from functions, such as -# contextlib.contextmanager. -contextmanager-decorators = [ - "contextlib.contextmanager", - "contextlib2.contextmanager", -] - -[tool.pylint.VARIABLES] -# Tells whether we should check for unused import in __init__ files. -init-import = false - -# A regular expression matching names used for dummy variables (i.e. not used). -dummy-variables-rgx = "^\\*{0,2}(_$|unused_|dummy_)" - -# List of additional names supposed to be defined in builtins. -additional-builtins = [] - -[tool.pylint.CLASSES] -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods = ["__init__", "__new__", "setUp"] - -# Valid names for the first argument to a class method. -valid-classmethod-first-arg = ["cls", "class_"] - -[tool.pylint.EXCEPTIONS] -overgeneral-exceptions = [ - "builtins.StandardError", - "builtins.Exception", - "builtins.BaseException", -] - -[tool.pylint.IMPORTS] -# Deprecated modules which should not be used, separated by a comma -deprecated-modules = ["regsub", "TERMIOS", "Bastion", "rexec", "sets"] - -[tool.pylint.FORMAT] -# List of checkers and warnings to disable. 
-disable = [ - "abstract-method", - "access-member-before-definition", - "arguments-differ", - "assignment-from-no-return", - "attribute-defined-outside-init", - "bad-mcs-classmethod-argument", - "bad-option-value", - "c-extension-no-member", - "consider-merging-isinstance", - "consider-using-dict-comprehension", - "consider-using-enumerate", - "consider-using-in", - "consider-using-set-comprehension", - "consider-using-ternary", - "deprecated-method", - "design", - "file-ignored", - "fixme", - "global-statement", - "import-error", - "inconsistent-return-statements", - "invalid-unary-operand-type", - "len-as-condition", - "locally-disabled", - "locally-enabled", - "misplaced-comparison-constant", - "missing-docstring", - "multiple-imports", - "no-else-return", - "no-member", - "no-name-in-module", - "no-self-use", - "no-value-for-parameter", - "not-an-iterable", - "not-context-manager", - "pointless-except", - "protected-access", - "redefined-argument-from-local", - "signature-differs", - "similarities", - "simplifiable-if-expression", - "star-args", - "super-init-not-called", - "suppressed-message", - "too-many-function-args", - "trailing-comma-tuple", - "trailing-newlines", - "ungrouped-imports", - "unnecessary-pass", - "unsubscriptable-object", - "unused-argument", - "useless-object-inheritance", - "useless-return", - "useless-suppression", - "wrong-import-order", - "wrong-import-position", - "unneeded-not", - "unexpected-keyword-arg", - "redundant-keyword-arg", - "unspecified-encoding", - "logging-fstring-interpolation", - "consider-using-f-string", - "use-dict-literal", +[tool.ruff] +line-length = 80 +indent-width = 2 +exclude = ["_version.py"] +target-version = "py311" + +[tool.ruff.format] +quote-style = "single" + +[tool.ruff.lint] +# Could add the commented out rules in the future: +extend-select = [ + "BLE", # disallow catch-all exceptions + "COM", # enforce trailing comma rules + "F", # Pyflakes rules + "FA", # Enforce from __future__ import annotations + "I", # Isort rules + "ICN", # Use common import conventions + "PLE", # Pylint Errors + "TID", # Some good import practices + # "A", # flake8-builtins: detect shadowed builtins + # "B", # flake8-bugbear: + # "C4", # flake8-comprehensions: catch incorrect use of comprehensions + # "D", # pydocstyle + # "DOC", # pydoclint + # "DTZ", # flake8-datetimez: strict timezone manipulation with datetime + # "E", # pycodestyle errors + # "FBT", # flake8-boolean-trap: detect boolean traps + # "ISC", # flake8-implicit-str-concat: good use of string concatenation + # "N", # pep8-naming: enforce naming conventions + # "NPY", # Some numpy-specific things + # "PL", # All Pylint rules + # "PLC", # Pylint Convention + # "PLR", # Pylint Refactor + # "PLW", # Pylint Warnings + # "PTH", # flake8-use-pathlib: use pathlib instead of os.path + # "RET", # flake8-return: good return practices + # "S", # flake8-bandit: security testing + # "SIM", # flake8-simplify: common simplification rules + # "TC", # flake8-type-checking: enforce importing certain types in a TYPE_CHECKING block + # "TD", # flake8-todo: Be diligent with TODO comments + # "UP", # pyupgrade: Warn if things can changed due to newer versions + # "W", # pycodestyle warnings ] -# Maximum number of characters on a single line. -max-line-length = 80 -ignore-long-lines = "(?x)(^\\s*(import|from)\\s|^\\s*(\\#\\ )??$|^[a-zA-Z_][a-zA-Z0-9_]*\\s*=\\s*('[^']\\S+'|\"[^\"]\\S+\"))" -# Maximum number of lines in a module -max-module-lines = 99999 -# String used as indentation unit. 
We differ from PEP8's normal 4 spaces. -indent-string = ' ' -single-line-if-stmt = true -# Do not warn about multiple statements on a single line for constructs like -# if test: stmt -[tool.pylint.LOGGING] -logging-modules = "logging,absl.logging" -# Add logging modules. -[tool.pylint.MISCELLANEOUS] -# Maximum line length for lambdas -#short-func-length=1 -# List of module members that should be marked as deprecated. -# All of the string functions are listed in 4.1.4 Deprecated string functions -# in the Python 2.4 docs. -#deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc,sys.maxint -# List of exceptions that do not need to be mentioned in the Raises section of -# a docstring. -#ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError -# Number of spaces of indent required when the last token on the preceding line -# is an open (, [, or {. -indent-after-paren = 4 +ignore = [ + # Conflicting lint rules with Ruff's formatter + # (see https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules). + "W191", + "E111", + "E114", + "E117", + "D206", + "D300", + "Q000", + "Q001", + "Q002", + "Q003", + "COM812", + "COM819", + "ISC001", + "ISC002", + "FBT001", + "FBT003", + "TD003", +] \ No newline at end of file diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index 3d8e35eaa..d080f2fb3 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -3,11 +3,11 @@ import functools from typing import Any, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec @@ -19,26 +19,29 @@ def get_batch_size(workload_name): def cosine_decay(lr, step, total_steps): - ratio = jnp.maximum(0., step / total_steps) - mult = 0.5 * (1. + jnp.cos(jnp.pi * ratio)) + ratio = jnp.maximum(0.0, step / total_steps) + mult = 0.5 * (1.0 + jnp.cos(jnp.pi * ratio)) return mult * lr -def create_learning_rate_fn(hparams: spec.Hyperparameters, - steps_per_epoch: int): +def create_learning_rate_fn( + hparams: spec.Hyperparameters, steps_per_epoch: int +): """Create learning rate schedule.""" - base_learning_rate = hparams.learning_rate * get_batch_size('cifar') / 128. 
+ base_learning_rate = hparams.learning_rate * get_batch_size('cifar') / 128.0 warmup_fn = optax.linear_schedule( - init_value=0., - end_value=base_learning_rate, - transition_steps=hparams.warmup_epochs * steps_per_epoch) + init_value=0.0, + end_value=base_learning_rate, + transition_steps=hparams.warmup_epochs * steps_per_epoch, + ) cosine_epochs = max(hparams.num_epochs - hparams.warmup_epochs, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=base_learning_rate, - decay_steps=cosine_epochs * steps_per_epoch) + init_value=base_learning_rate, decay_steps=cosine_epochs * steps_per_epoch + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], - boundaries=[hparams.warmup_epochs * steps_per_epoch]) + schedules=[warmup_fn, cosine_fn], + boundaries=[hparams.warmup_epochs * steps_per_epoch], + ) return schedule_fn @@ -46,51 +49,59 @@ def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): steps_per_epoch = num_train_examples // get_batch_size('cifar') learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch) opt_init_fn, opt_update_fn = optax.sgd( - nesterov=True, - momentum=hyperparameters.momentum, - learning_rate=learning_rate_fn) + nesterov=True, + momentum=hyperparameters.momentum, + learning_rate=learning_rate_fn, + ) return opt_init_fn, opt_update_fn -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: del model_params del model_state del rng - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters, - workload.num_train_examples) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) + opt_init_fn, opt_update_fn = optimizer( + hyperparameters, workload.num_train_examples + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, None, 0, 0), + static_broadcasted_argnums=(0, 1), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + hyperparameters, + batch, + rng, +): def _loss_fn(params): """loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn(batch['targets'], logits) loss = loss_dict['summed'] / loss_dict['n_valid_examples'] weight_penalty_params = jax.tree_util.tree_leaves(params) @@ -102,25 +113,27 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (_, new_model_state), grad = grad_fn(current_param_container) grad = lax.pmean(grad, 
axis_name='batch') - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -130,21 +143,30 @@ def update_params( optimizer_state, opt_update_fn = optimizer_state per_device_rngs = jax.random.split(rng, jax.local_device_count()) new_optimizer_state, new_params, new_model_state = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, hyperparameters, batch, per_device_rngs) + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + hyperparameters, + batch, + per_device_rngs, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -158,14 +180,16 @@ def prepare_for_eval(workload: spec.Workload, # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. 
-def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index d8b91f83a..e8080fe34 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -3,9 +3,7 @@ from typing import Any, Dict, Iterator, List, Optional, Tuple import torch -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec @@ -16,56 +14,63 @@ def get_batch_size(workload_name): return batch_sizes[workload_name] -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: del workload del model_state del rng - base_lr = hyperparameters.learning_rate * get_batch_size('cifar') / 128. 
+ base_lr = hyperparameters.learning_rate * get_batch_size('cifar') / 128.0 optimizer_state = { - 'optimizer': - torch.optim.SGD( - model_params.parameters(), - lr=base_lr, - momentum=hyperparameters.momentum, - weight_decay=hyperparameters.l2), + 'optimizer': torch.optim.SGD( + model_params.parameters(), + lr=base_lr, + momentum=hyperparameters.momentum, + weight_decay=hyperparameters.l2, + ), } scheduler1 = LinearLR( - optimizer_state['optimizer'], - start_factor=1e-5, - end_factor=1., - total_iters=hyperparameters.warmup_epochs) + optimizer_state['optimizer'], + start_factor=1e-5, + end_factor=1.0, + total_iters=hyperparameters.warmup_epochs, + ) cosine_epochs = max( - hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1) + hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1 + ) scheduler2 = CosineAnnealingLR( - optimizer_state['optimizer'], T_max=cosine_epochs) + optimizer_state['optimizer'], T_max=cosine_epochs + ) optimizer_state['scheduler'] = SequentialLR( - optimizer_state['optimizer'], - schedulers=[scheduler1, scheduler2], - milestones=[hyperparameters.warmup_epochs]) + optimizer_state['optimizer'], + schedulers=[scheduler1, scheduler2], + milestones=[hyperparameters.warmup_epochs], + ) return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del current_params_types del hyperparameters @@ -78,15 +83,17 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], logits_batch=logits_batch) + label_batch=batch['targets'], logits_batch=logits_batch + ) loss = loss_dict['summed'] / loss_dict['n_valid_examples'] loss.backward() @@ -99,16 +106,18 @@ def update_params( return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + 
workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -122,14 +131,16 @@ def prepare_for_eval(workload: spec.Workload, # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. diff --git a/reference_algorithms/development_algorithms/cifar/tuning_search_space.json b/reference_algorithms/development_algorithms/cifar/tuning_search_space.json index 283341705..aa8fcacfd 100644 --- a/reference_algorithms/development_algorithms/cifar/tuning_search_space.json +++ b/reference_algorithms/development_algorithms/cifar/tuning_search_space.json @@ -1,7 +1,7 @@ { - "learning_rate": {"feasible_points": [0.1]}, - "warmup_epochs": {"feasible_points": [5]}, - "num_epochs": {"feasible_points": [200]}, - "l2": {"feasible_points": [5e-4]}, - "momentum": {"feasible_points": [0.9]} + "learning_rate": { "feasible_points": [0.1] }, + "warmup_epochs": { "feasible_points": [5] }, + "num_epochs": { "feasible_points": [200] }, + "l2": { "feasible_points": [5e-4] }, + "momentum": { "feasible_points": [0.9] } } diff --git a/reference_algorithms/development_algorithms/mnist/discrete_space.json b/reference_algorithms/development_algorithms/mnist/discrete_space.json index 310f19e7d..8056d4861 100644 --- a/reference_algorithms/development_algorithms/mnist/discrete_space.json +++ b/reference_algorithms/development_algorithms/mnist/discrete_space.json @@ -1,17 +1,17 @@ [ - { - "learning_rate": 1e-3, - "one_minus_beta_1": 0.999, - "epsilon": 0.9 - }, - { - "learning_rate": 1e-2, - "one_minus_beta_1": 0.99, - "epsilon": 0.99 - }, - { - "learning_rate": 1e-1, - "one_minus_beta_1": 0.9, - "epsilon": 0.999 - } -] \ No newline at end of file + { + "learning_rate": 1e-3, + "one_minus_beta_1": 0.999, + "epsilon": 0.9 + }, + { + "learning_rate": 1e-2, + "one_minus_beta_1": 0.99, + "epsilon": 0.99 + }, + { + "learning_rate": 1e-1, + "one_minus_beta_1": 0.9, + "epsilon": 0.999 + } +] diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index c1f54597d..afdd1bd43 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -3,11 +3,11 @@ import functools 
from typing import Any, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec @@ -18,50 +18,59 @@ def get_batch_size(workload_name): return batch_sizes[workload_name] -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: del model_params del model_state del rng - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) opt_init_fn, opt_update_fn = optax.chain( - optax.scale_by_adam( - b1=1.0 - hyperparameters.one_minus_beta_1, - b2=0.999, - eps=hyperparameters.epsilon), - optax.scale(-hyperparameters.learning_rate)) + optax.scale_by_adam( + b1=1.0 - hyperparameters.one_minus_beta_1, + b2=0.999, + eps=hyperparameters.epsilon, + ), + optax.scale(-hyperparameters.learning_rate), + ) return jax_utils.replicate(opt_init_fn(params_zeros_like)), opt_update_fn # We need to jax.pmap here instead of inside update_params because the latter # would recompile the function every step. @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, None, 0, 0, 0), - static_broadcasted_argnums=(0, 1)) -def pmapped_update_params(workload: spec.Workload, - opt_update_fn, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - optimizer_state: spec.OptimizerState, - rng: spec.RandomState) -> spec.UpdateReturn: + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, None, 0, 0, 0), + static_broadcasted_argnums=(0, 1), +) +def pmapped_update_params( + workload: spec.Workload, + opt_update_fn, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + optimizer_state: spec.OptimizerState, + rng: spec.RandomState, +) -> spec.UpdateReturn: del hyperparameters def loss_fn(params): logits_batch, new_model_state = workload.model_fn( - params=params, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=params, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn(batch['targets'], logits_batch) loss = loss_dict['summed'] / loss_dict['n_valid_examples'] return loss, new_model_state @@ -69,25 +78,27 @@ def loss_fn(params): grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, new_model_state), grad = grad_fn(current_param_container) grad = lax.pmean(grad, axis_name='batch') - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, 
updated_params, new_model_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -98,27 +109,30 @@ def update_params( per_device_rngs = jax.random.split(rng, jax.local_device_count()) optimizer_state, opt_update_fn = optimizer_state new_optimizer_state, updated_params, new_model_state = pmapped_update_params( - workload, - opt_update_fn, - current_param_container, - model_state, - hyperparameters, - batch, - optimizer_state, - per_device_rngs) + workload, + opt_update_fn, + current_param_container, + model_state, + hyperparameters, + batch, + optimizer_state, + per_device_rngs, + ) return (new_optimizer_state, opt_update_fn), updated_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -132,14 +146,16 @@ def prepare_for_eval(workload: spec.Workload, # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. 
Each element of the queue is a batch of training examples and labels. diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py index dedd96793..9940fca6e 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py @@ -13,38 +13,41 @@ def get_batch_size(workload_name): return batch_sizes[workload_name] -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: del model_state del workload del rng optimizer_state = { - 'optimizer': - torch.optim.Adam( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta_1, 0.999), - eps=hyperparameters.epsilon), + 'optimizer': torch.optim.Adam( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta_1, 0.999), + eps=hyperparameters.epsilon, + ), } return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del hyperparameters del loss_type @@ -59,15 +62,17 @@ def update_params( param.grad = None output, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], logits_batch=output) + label_batch=batch['targets'], logits_batch=output + ) loss = loss_dict['summed'] / loss_dict['n_valid_examples'] loss.backward() optimizer_state['optimizer'].step() @@ -75,16 +80,18 @@ def update_params( return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - 
optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -98,14 +105,16 @@ def prepare_for_eval(workload: spec.Workload, # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. diff --git a/reference_algorithms/development_algorithms/mnist/tuning_search_space.json b/reference_algorithms/development_algorithms/mnist/tuning_search_space.json index 35b941133..bf7d3f1c1 100644 --- a/reference_algorithms/development_algorithms/mnist/tuning_search_space.json +++ b/reference_algorithms/development_algorithms/mnist/tuning_search_space.json @@ -1,5 +1,5 @@ { - "learning_rate": {"min": 1e-4, "max": 1e-2, "scaling": "log"}, - "one_minus_beta_1": {"min": 0.9, "max": 0.999, "scaling": "log"}, - "epsilon": {"feasible_points": [1e-8, 1e-5, 1e-3]} + "learning_rate": { "min": 1e-4, "max": 1e-2, "scaling": "log" }, + "one_minus_beta_1": { "min": 0.9, "max": 0.999, "scaling": "log" }, + "epsilon": { "feasible_points": [1e-8, 1e-5, 1e-3] } } diff --git a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py index ff98464ae..32ba97da4 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py @@ -30,16 +30,17 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union import jax -from jax import numpy as jnp import optax +from jax import numpy as jnp JTensor = Any NestedJTensor = Any NestedHParams = Any -def to_quantized(fvalue: JTensor, - quantized_dtype: jnp.dtype) -> Tuple[JTensor, JTensor]: +def to_quantized( + fvalue: JTensor, quantized_dtype: jnp.dtype +) -> Tuple[JTensor, JTensor]: """Converts floating point values `fvalues` to quantized values. We use a very simple quantization scheme where the range is symmetric around @@ -82,16 +83,17 @@ def to_quantized(fvalue: JTensor, # We first decide the scale. 
if fvalue.ndim < 1: raise ValueError( - f'Input array {fvalue} must have a strictly positive number of ' - 'dimensions.') + f'Input array {fvalue} must have a strictly positive number of ' + 'dimensions.' + ) max_abs = jnp.max(jnp.abs(fvalue), axis=0) bucket_size = max_abs / num_buckets bs_expanded = bucket_size[jnp.newaxis, ...] # To avoid divide by 0.0 - bs_nonzero = jnp.where(bs_expanded > 0.0, - bs_expanded, - jnp.ones_like(bs_expanded)) + bs_nonzero = jnp.where( + bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded) + ) ratio = fvalue / bs_nonzero # We use rounding to remove bias. quantized = jnp.round(ratio) @@ -128,8 +130,8 @@ def adafactor_decay_rate_adam(beta2: float, step_counter: JTensor) -> JTensor: """ step = step_counter beta2 = jnp.array(beta2, dtype=jnp.float32) - t = step + 1. - return beta2 * (1. - jnp.power(beta2, t - 1.)) / (1. - jnp.power(beta2, t)) + t = step + 1.0 + return beta2 * (1.0 - jnp.power(beta2, t - 1.0)) / (1.0 - jnp.power(beta2, t)) def adafactor_decay_rate_pow(exponent: float, step_counter: JTensor) -> JTensor: @@ -145,7 +147,7 @@ def adafactor_decay_rate_pow(exponent: float, step_counter: JTensor) -> JTensor: """ step = step_counter exponent = jnp.array(exponent, dtype=jnp.float32) - return 1. - jnp.power((step + 1.), -exponent) + return 1.0 - jnp.power((step + 1.0), -exponent) def reduce_mean(array: JTensor) -> JTensor: @@ -187,6 +189,7 @@ def reduce_rms(array: JTensor) -> JTensor: @dataclasses.dataclass(frozen=True) class _ShardedAdafactorUpdateResult: """Structure containing per-variable info for Adafactor.""" + update: Optional[Any] m: Optional[Any] m_scale: Optional[Any] @@ -197,6 +200,7 @@ class _ShardedAdafactorUpdateResult: class ShardedAdafactorState(NamedTuple): """Overall state of the ShardedAdafactor optimizer.""" + count: JTensor m: Optional[NestedJTensor] m_scale: Optional[NestedJTensor] @@ -208,27 +212,29 @@ class ShardedAdafactorState(NamedTuple): class _ShardedAdafactorHelper: """Helper class to implement optax-based sharded Adafactor.""" - def __init__(self, - learning_rate: optax.Schedule, - weight_decay: Optional[float], - layerwise_adaptation: bool, - decay_method: str, - decay_adam: float, - decay_pow: float, - beta1: float, - clip_threshold: Optional[float], - factored: bool, - epsilon1_grad_sq_reg: float, - quantized_dtype: jnp.dtype, - respect_skip_lp_regularization: bool, - exclude_from_layerwise_adaptation: Optional[List[str]], - per_var_learning_summary: bool, - sort_factored_second_moment_dims: bool, - min_dim_size_to_factor: int, - multiply_by_parameter_scale: bool, - epsilon2_param_scale_reg: float, - maybe_inf_to_nan: bool, - nesterov: bool) -> None: + def __init__( + self, + learning_rate: optax.Schedule, + weight_decay: Optional[float], + layerwise_adaptation: bool, + decay_method: str, + decay_adam: float, + decay_pow: float, + beta1: float, + clip_threshold: Optional[float], + factored: bool, + epsilon1_grad_sq_reg: float, + quantized_dtype: jnp.dtype, + respect_skip_lp_regularization: bool, + exclude_from_layerwise_adaptation: Optional[List[str]], + per_var_learning_summary: bool, + sort_factored_second_moment_dims: bool, + min_dim_size_to_factor: int, + multiply_by_parameter_scale: bool, + epsilon2_param_scale_reg: float, + maybe_inf_to_nan: bool, + nesterov: bool, + ) -> None: """Constructor. 
See ShardedAdafactor() below.""" self._learning_rate = learning_rate @@ -315,12 +321,13 @@ def should_store_momentum_in_qint(self, shape): def to_state(self, count, result_tree): """Maps from a tree of (factored) values to separate trees of values.""" return ShardedAdafactorState( - count=count, - m=jax.tree.map(lambda o: o.m, result_tree), - m_scale=jax.tree.map(lambda o: o.m_scale, result_tree), - vr=jax.tree.map(lambda o: o.vr, result_tree), - vc=jax.tree.map(lambda o: o.vc, result_tree), - v=jax.tree.map(lambda o: o.v, result_tree)) + count=count, + m=jax.tree.map(lambda o: o.m, result_tree), + m_scale=jax.tree.map(lambda o: o.m_scale, result_tree), + vr=jax.tree.map(lambda o: o.vr, result_tree), + vc=jax.tree.map(lambda o: o.vc, result_tree), + v=jax.tree.map(lambda o: o.v, result_tree), + ) def init(self, param): """Initializes the optimizer state for a given param.""" @@ -353,12 +360,13 @@ def init(self, param): else: output_v = jnp.zeros(shape, dtype=jnp.float32) return _ShardedAdafactorUpdateResult( - update=output_update, - m=output_m, - m_scale=output_m_scale, - vr=output_vr, - vc=output_vc, - v=output_v) + update=output_update, + m=output_m, + m_scale=output_m_scale, + vr=output_vr, + vc=output_vc, + v=output_v, + ) def inf_to_nan(self, array): """Converting Infinity values to the more sticky NaN.""" @@ -386,16 +394,9 @@ def parameter_scale(self, var): """ return jnp.maximum(reduce_rms(var), jnp.asarray(self._epsilon2, var.dtype)) - def compute_var_and_slot_update(self, - count, - grad, - m, - m_scale, - vr, - vc, - v, - param, - var_name=None): + def compute_var_and_slot_update( + self, count, grad, m, m_scale, vr, vc, v, param, var_name=None + ): """Computes the var and optimizer slots updates for a single variable.""" # We can probably skip this step grad = grad.astype(jnp.float32) @@ -434,7 +435,7 @@ def compute_var_and_slot_update(self, update_scale += grad_squared_mean * 1e-30 # END HACK - mixing_rate = 1. - decay_rate + mixing_rate = 1.0 - decay_rate shape = param.shape output_m = jnp.zeros((1,)) @@ -449,18 +450,23 @@ def compute_var_and_slot_update(self, # reduce_mean(). vr_axis, vc_axis = factored_second_moment_dims grad_squared_row_mean = self.inf_to_nan( - jnp.mean(grad_squared, axis=vr_axis)) + jnp.mean(grad_squared, axis=vr_axis) + ) grad_squared_col_mean = self.inf_to_nan( - jnp.mean(grad_squared, axis=vc_axis)) + jnp.mean(grad_squared, axis=vc_axis) + ) new_vr = decay_rate * vr + mixing_rate * grad_squared_row_mean new_vc = decay_rate * vc + mixing_rate * grad_squared_col_mean output_vr = new_vr output_vc = new_vc long_term_mean = jnp.mean(new_vr, axis=-1, keepdims=True) - r_factor = 1. / jnp.sqrt(new_vr / long_term_mean) - c_factor = 1. / jnp.sqrt(new_vc) - x = grad * jnp.expand_dims(r_factor, vr_axis) * jnp.expand_dims( - c_factor, vc_axis) + r_factor = 1.0 / jnp.sqrt(new_vr / long_term_mean) + c_factor = 1.0 / jnp.sqrt(new_vc) + x = ( + grad + * jnp.expand_dims(r_factor, vr_axis) + * jnp.expand_dims(c_factor, vc_axis) + ) else: # v with sharding annotation. 
new_v = decay_rate * v + mixing_rate * grad_squared @@ -468,7 +474,7 @@ def compute_var_and_slot_update(self, x = grad / jnp.sqrt(new_v) if self._clip_threshold is not None: - clipping_denom = jnp.maximum(1., reduce_rms(x) / self._clip_threshold) + clipping_denom = jnp.maximum(1.0, reduce_rms(x) / self._clip_threshold) clipping_denom = self.inf_to_nan(clipping_denom) x /= clipping_denom @@ -481,7 +487,7 @@ def compute_var_and_slot_update(self, m = to_float(m, m_scale) if self._nesterov: subtrahend_original = subtrahend - subtrahend = self._beta1 * m + (1. - self._beta1) * subtrahend + subtrahend = self._beta1 * m + (1.0 - self._beta1) * subtrahend subtrahend = self.inf_to_nan(subtrahend) if self._quantized_dtype == jnp.bfloat16: new_m = subtrahend.astype(jnp.bfloat16) @@ -496,8 +502,8 @@ def compute_var_and_slot_update(self, if self._nesterov: subtrahend = ( - self._beta1 * subtrahend + - (1.0 - self._beta1) * subtrahend_original) + self._beta1 * subtrahend + (1.0 - self._beta1) * subtrahend_original + ) if self._weight_decay is not None: # Apply decoupled weight decay to be consistent with AdamW. @@ -527,43 +533,45 @@ def compute_var_and_slot_update(self, g_norm = reduce_rms(subtrahend / update_scale) + self._epsilon1 ratio = w_norm / g_norm ratio = jnp.where( - jnp.greater(w_norm, 0), - jnp.where(jnp.greater(g_norm, 0), (w_norm / g_norm), 1.0), - 1.0) + jnp.greater(w_norm, 0), + jnp.where(jnp.greater(g_norm, 0), (w_norm / g_norm), 1.0), + 1.0, + ) subtrahend *= ratio return _ShardedAdafactorUpdateResult( - update=-subtrahend, - m=output_m, - m_scale=output_m_scale, - vr=output_vr, - vc=output_vc, - v=output_v) + update=-subtrahend, + m=output_m, + m_scale=output_m_scale, + vr=output_vr, + vc=output_vc, + v=output_v, + ) def sharded_adafactor( - learning_rate: optax.Schedule, - weight_decay: Optional[Union[float, Dict[str, float]]] = None, - layerwise_adaptation: bool = False, - decay_method: str = 'adam', - decay_adam: float = 0.99, - decay_pow: float = 0., - beta1: float = 0.9, - clip_threshold: Optional[float] = 1., - factored: bool = True, - epsilon1_grad_sq_reg: float = 1e-30, - quantized_dtype: jnp.dtype = jnp.int8, - respect_skip_lp_regularization: bool = False, - exclude_from_layerwise_adaptation: Optional[List[str]] = None, - per_var_learning_summary: bool = False, - sort_factored_second_moment_dims: bool = False, - # min_dim_size_to_factor is only used when - # sort_factored_second_moment_dims=True. - min_dim_size_to_factor: int = 128, - multiply_by_parameter_scale: bool = False, - epsilon2_param_scale_reg: float = 1e-3, - maybe_inf_to_nan: bool = True, - nesterov: bool = False, + learning_rate: optax.Schedule, + weight_decay: Optional[Union[float, Dict[str, float]]] = None, + layerwise_adaptation: bool = False, + decay_method: str = 'adam', + decay_adam: float = 0.99, + decay_pow: float = 0.0, + beta1: float = 0.9, + clip_threshold: Optional[float] = 1.0, + factored: bool = True, + epsilon1_grad_sq_reg: float = 1e-30, + quantized_dtype: jnp.dtype = jnp.int8, + respect_skip_lp_regularization: bool = False, + exclude_from_layerwise_adaptation: Optional[List[str]] = None, + per_var_learning_summary: bool = False, + sort_factored_second_moment_dims: bool = False, + # min_dim_size_to_factor is only used when + # sort_factored_second_moment_dims=True. 
+ min_dim_size_to_factor: int = 128, + multiply_by_parameter_scale: bool = False, + epsilon2_param_scale_reg: float = 1e-3, + maybe_inf_to_nan: bool = True, + nesterov: bool = False, ) -> optax.GradientTransformation: """AdaFactor optimizer that supports SPMD sharding. @@ -638,53 +646,60 @@ def sharded_adafactor( assert decay_pow >= 0 assert learning_rate is not None assert decay_method == 'adam' or decay_method == 'pow', ( - f'decay_method: {decay_method} not supported. Supported methods are ' - '"pow", or "adam".') + f'decay_method: {decay_method} not supported. Supported methods are ' + '"pow", or "adam".' + ) sharded_adafactor_helper = _ShardedAdafactorHelper( - learning_rate=learning_rate, - weight_decay=weight_decay, - layerwise_adaptation=layerwise_adaptation, - decay_method=decay_method, - decay_adam=decay_adam, - decay_pow=decay_pow, - beta1=beta1, - clip_threshold=clip_threshold, - factored=factored, - epsilon1_grad_sq_reg=epsilon1_grad_sq_reg, - quantized_dtype=quantized_dtype, - respect_skip_lp_regularization=respect_skip_lp_regularization, - exclude_from_layerwise_adaptation=exclude_from_layerwise_adaptation, - per_var_learning_summary=per_var_learning_summary, - sort_factored_second_moment_dims=sort_factored_second_moment_dims, - min_dim_size_to_factor=min_dim_size_to_factor, - multiply_by_parameter_scale=multiply_by_parameter_scale, - epsilon2_param_scale_reg=epsilon2_param_scale_reg, - maybe_inf_to_nan=maybe_inf_to_nan, - nesterov=nesterov) + learning_rate=learning_rate, + weight_decay=weight_decay, + layerwise_adaptation=layerwise_adaptation, + decay_method=decay_method, + decay_adam=decay_adam, + decay_pow=decay_pow, + beta1=beta1, + clip_threshold=clip_threshold, + factored=factored, + epsilon1_grad_sq_reg=epsilon1_grad_sq_reg, + quantized_dtype=quantized_dtype, + respect_skip_lp_regularization=respect_skip_lp_regularization, + exclude_from_layerwise_adaptation=exclude_from_layerwise_adaptation, + per_var_learning_summary=per_var_learning_summary, + sort_factored_second_moment_dims=sort_factored_second_moment_dims, + min_dim_size_to_factor=min_dim_size_to_factor, + multiply_by_parameter_scale=multiply_by_parameter_scale, + epsilon2_param_scale_reg=epsilon2_param_scale_reg, + maybe_inf_to_nan=maybe_inf_to_nan, + nesterov=nesterov, + ) def init_fn(params): """Initializes the optimizer's state.""" return sharded_adafactor_helper.to_state( - jnp.zeros([], jnp.int32), - jax.tree.map(sharded_adafactor_helper.init, params)) + jnp.zeros([], jnp.int32), + jax.tree.map(sharded_adafactor_helper.init, params), + ) def update_fn(updates, state, params=None): if params is None: raise ValueError( - 'You are using a transformation that requires the current value of ' - 'parameters, but you are not passing `params` when calling `update`.') + 'You are using a transformation that requires the current value of ' + 'parameters, but you are not passing `params` when calling `update`.' 
+ ) compute_var_and_slot_update_fn = functools.partial( - sharded_adafactor_helper.compute_var_and_slot_update, state.count) - output = jax.tree.map(compute_var_and_slot_update_fn, - updates, - state.m, - state.m_scale, - state.vr, - state.vc, - state.v, - params) + sharded_adafactor_helper.compute_var_and_slot_update, state.count + ) + output = jax.tree.map( + compute_var_and_slot_update_fn, + updates, + state.m, + state.m_scale, + state.vr, + state.vc, + state.v, + params, + ) updates = jax.tree.map(lambda o: o.update, output) count_plus_one = state.count + jnp.array(1, jnp.int32) updated_states = sharded_adafactor_helper.to_state(count_plus_one, output) diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index 1833ab8af..8dcaa6578 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -3,24 +3,27 @@ import functools from typing import Any, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec -from reference_algorithms.paper_baselines.adafactor.jax.sharded_adafactor import \ - sharded_adafactor +from reference_algorithms.paper_baselines.adafactor.jax.sharded_adafactor import ( + sharded_adafactor, +) _GRAD_CLIP_EPS = 1e-6 -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates an Adafactor optimizer and a learning rate schedule.""" del model_params del model_state @@ -30,99 +33,113 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create optimizer + LR schedule. 
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) opt_init_fn, opt_update_fn = sharded_adafactor( - learning_rate=lr_schedule_fn, - beta1=1.0 - hyperparameters.one_minus_beta1, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + learning_rate=lr_schedule_fn, + beta1=1.0 - hyperparameters.one_minus_beta1, + weight_decay=hyperparameters.weight_decay, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. 
(summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -139,37 +156,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. 
if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -205,14 +228,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. 
""" diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index 7aa457a25..4c96e5562 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -3,12 +3,10 @@ from functools import partial from typing import Any, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from absl import logging +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec from algoperf.pytorch_utils import pytorch_setup @@ -16,36 +14,40 @@ USE_PYTORCH_DDP = pytorch_setup()[0] -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates an Adafactor optimizer and a learning rate schedule.""" del model_state del rng # Create optimizer. optimizer_state = { - 'optimizer': - Adafactor( - model_params.parameters(), - lr=hyperparameters.learning_rate, - beta1=1 - hyperparameters.one_minus_beta1, - weight_decay=hyperparameters.weight_decay), + 'optimizer': Adafactor( + model_params.parameters(), + lr=hyperparameters.learning_rate, + beta1=1 - hyperparameters.one_minus_beta1, + weight_decay=hyperparameters.weight_decay, + ), } optimizer = optimizer_state['optimizer'] warmup = LinearLR( - optimizer, - start_factor=1e-10, - end_factor=1., - total_iters=hyperparameters.warmup_steps) + optimizer, + start_factor=1e-10, + end_factor=1.0, + total_iters=hyperparameters.warmup_steps, + ) cosine_steps = max(workload.step_hint - hyperparameters.warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) optimizer_state['scheduler'] = SequentialLR( - optimizer, - schedulers=[warmup, cosine_decay], - milestones=[hyperparameters.warmup_steps]) + optimizer, + schedulers=[warmup, cosine_decay], + milestones=[hyperparameters.warmup_steps], + ) return optimizer_state @@ -54,56 +56,56 @@ class Adafactor(torch.optim.Optimizer): src/transformers/optimization.py#L386""" def __init__( - self, - params, - lr=None, - beta1=0.9, - decay_adam=0.99, - weight_decay=0.0, + self, + params, + lr=None, + beta1=0.9, + decay_adam=0.99, + weight_decay=0.0, ): defaults = dict( - lr=lr, - beta1=beta1, - decay_adam=decay_adam, - weight_decay=weight_decay, - decay_pow=0.0, - layerwise_adaptation=False, - decay_method='adam', - clip_threshold=1.0, - factored=True, - epsilon1_grad_sq_reg=1e-30, - respect_skip_lp_regularization=False, - exclude_from_layerwise_adaptation=None, - per_var_learning_summary=False, - sort_factored_second_moment_dims=False, - # Unused because sort_factored_second_moment_dims=False. - min_dim_size_to_factor=128, - multiply_by_parameter_scale=False, - # Unused because multiply_by_parameter_scale=False. 
- epsilon2_param_scale_reg=1e-3, - maybe_inf_to_nan=True, + lr=lr, + beta1=beta1, + decay_adam=decay_adam, + weight_decay=weight_decay, + decay_pow=0.0, + layerwise_adaptation=False, + decay_method='adam', + clip_threshold=1.0, + factored=True, + epsilon1_grad_sq_reg=1e-30, + respect_skip_lp_regularization=False, + exclude_from_layerwise_adaptation=None, + per_var_learning_summary=False, + sort_factored_second_moment_dims=False, + # Unused because sort_factored_second_moment_dims=False. + min_dim_size_to_factor=128, + multiply_by_parameter_scale=False, + # Unused because multiply_by_parameter_scale=False. + epsilon2_param_scale_reg=1e-3, + maybe_inf_to_nan=True, ) super().__init__(params, defaults) def inf_to_nan(self, group, x): - if group["maybe_inf_to_nan"]: + if group['maybe_inf_to_nan']: x = torch.nan_to_num(x, nan=torch.nan, posinf=torch.nan, neginf=torch.nan) return x def step(self, closure=None): """ - Performs a single optimization step - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ + Performs a single optimization step + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ loss = None if closure is not None: loss = closure() for group in self.param_groups: inf_to_nan = partial(self.inf_to_nan, group) - for p in group["params"]: + for p in group['params']: if p.grad is None: continue grad = p.grad.data @@ -111,7 +113,7 @@ def step(self, closure=None): if grad.dtype in {torch.float16, torch.bfloat16}: grad = grad.float() if grad.is_sparse: - raise RuntimeError("Adafactor does not support sparse gradients.") + raise RuntimeError('Adafactor does not support sparse gradients.') state = self.state[p] grad_shape = grad.shape @@ -120,51 +122,54 @@ def step(self, closure=None): # State Initialization if len(state) == 0: - state["step"] = 0 - state["exp_avg"] = torch.zeros_like(grad) + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(grad) if factored: - state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) - state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + - grad_shape[-1:]).to(grad) + state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad) + state['exp_avg_sq_col'] = torch.zeros( + grad_shape[:-2] + grad_shape[-1:] + ).to(grad) else: - state["exp_avg_sq"] = torch.zeros_like(grad) + state['exp_avg_sq'] = torch.zeros_like(grad) else: - state["exp_avg"] = state["exp_avg"].to(grad) + state['exp_avg'] = state['exp_avg'].to(grad) if factored: - state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) - state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad) + state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad) else: - state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + state['exp_avg_sq'] = state['exp_avg_sq'].to(grad) p_data_fp32 = p.data if p.data.dtype in {torch.float16, torch.bfloat16}: p_data_fp32 = p_data_fp32.float() - state["step"] += 1 - lr = group["lr"] - beta1 = group["beta1"] - beta2 = group["decay_adam"] + state['step'] += 1 + lr = group['lr'] + beta1 = group['beta1'] + beta2 = group['decay_adam'] - t = state["step"] - beta2t = beta2 * (1. - beta2**(t - 1.)) / (1. 
- beta2**t) + t = state['step'] + beta2t = beta2 * (1.0 - beta2 ** (t - 1.0)) / (1.0 - beta2**t) - exp_avg_sq_update = (grad**2) + group["epsilon1_grad_sq_reg"] + exp_avg_sq_update = (grad**2) + group['epsilon1_grad_sq_reg'] if factored: - exp_avg_sq_row = state["exp_avg_sq_row"] - exp_avg_sq_col = state["exp_avg_sq_col"] + exp_avg_sq_row = state['exp_avg_sq_row'] + exp_avg_sq_col = state['exp_avg_sq_col'] exp_avg_sq_row.mul_(beta2t).add_( - exp_avg_sq_update.mean(dim=-1), alpha=1.0 - beta2t) + exp_avg_sq_update.mean(dim=-1), alpha=1.0 - beta2t + ) exp_avg_sq_col.mul_(beta2t).add_( - exp_avg_sq_update.mean(dim=-2), alpha=1.0 - beta2t) + exp_avg_sq_update.mean(dim=-2), alpha=1.0 - beta2t + ) r_factor = inf_to_nan( - exp_avg_sq_row / - exp_avg_sq_row.mean(dim=-1, keepdim=True)).unsqueeze(-1) + exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True) + ).unsqueeze(-1) c_factor = inf_to_nan(exp_avg_sq_col).unsqueeze(-2) denom = r_factor * c_factor else: - exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq = state['exp_avg_sq'] exp_avg_sq.mul_(beta2t).add_(exp_avg_sq_update, alpha=1.0 - beta2t) denom = exp_avg_sq @@ -172,15 +177,16 @@ def step(self, closure=None): denom = denom.sqrt() update = grad / denom # Clip the update based on RMS. - clipping_denom = inf_to_nan(torch.square(update).mean().sqrt() \ - /group["clip_threshold"]).clamp(min=1.0) + clipping_denom = inf_to_nan( + torch.square(update).mean().sqrt() / group['clip_threshold'] + ).clamp(min=1.0) update = update / clipping_denom * lr # Momentum - exp_avg = state["exp_avg"] + exp_avg = state['exp_avg'] exp_avg.mul_(beta1).add_(update, alpha=1 - beta1) - if group["weight_decay"] != 0: - p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * lr) + if group['weight_decay'] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr) p_data_fp32.add_(-exp_avg) @@ -191,18 +197,19 @@ def step(self, closure=None): def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -214,22 +221,26 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 
'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -243,12 +254,14 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -256,28 +269,34 @@ def update_params( if global_step <= 100 or global_step % 500 == 0: if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -314,14 +333,15 @@ def get_batch_size(workload_name): def data_selection( - workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. 
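# NOTE (illustrative sketch, not from this patch): the Adafactor step() above
# keeps per-row and per-column EMAs of the squared gradient and rebuilds an
# approximate second moment from them. A single-step version with made-up
# shapes, assuming the same row/column statistics as the code above:
import torch

grad = torch.randn(8, 16)
sq = grad**2 + 1e-30                      # epsilon1_grad_sq_reg regularizer
row = sq.mean(dim=-1)                     # plays the role of exp_avg_sq_row
col = sq.mean(dim=-2)                     # plays the role of exp_avg_sq_col
r_factor = (row / row.mean(dim=-1, keepdim=True)).unsqueeze(-1)
c_factor = col.unsqueeze(-2)
denom = (r_factor * c_factor).sqrt()      # rank-1 approximation of sqrt(V)
update = grad / denom                     # pre-clipping, pre-momentum update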
Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json b/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json index 5543689ea..37b36e55d 100644 --- a/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json +++ b/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json @@ -1,20 +1,26 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "min": 1e-2, "max": 0.45, "scaling": "log" - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 1e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "min": 1e-2, + "max": 0.45, + "scaling": "log" + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 1e-3, + "max": 1.0, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + } } diff --git a/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json index 98a506084..35f106f9d 100644 --- a/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json +++ b/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json @@ -1,20 +1,24 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "feasible_points": [0.1] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 1e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "feasible_points": [0.1] + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 1e-3, + "max": 1.0, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + } } diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index dde41fa6d..17d0c2fc2 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -3,22 +3,24 @@ import functools from typing import Any, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec _GRAD_CLIP_EPS = 1e-6 -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates an AdamW optimizer and a learning rate schedule.""" del model_params del model_state @@ -28,101 +30,115 @@ def 
jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create optimizer + LR schedule. lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) opt_init_fn, opt_update_fn = optax.adamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. 
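# NOTE (illustrative sketch; names and values are made up, not from this
# patch): the code below psum-reduces per-device loss sums, example counts and
# gradients over the 'batch' axis, then rescales the gradient tree so its
# global L2 norm is at most grad_clip. A standalone toy version:
import functools
import jax
import jax.numpy as jnp
from jax import lax

_GRAD_CLIP_EPS = 1e-6


@functools.partial(jax.pmap, axis_name='batch', in_axes=(0, 0, 0, None))
def _global_mean_and_clip(summed_loss, n_valid_examples, grad, grad_clip):
  # Sum the per-device partial results across the pmap axis.
  summed_loss, n_valid_examples, grad = lax.psum(
    (summed_loss, n_valid_examples, grad), axis_name='batch'
  )
  loss = summed_loss / n_valid_examples
  grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
  # Rescale so the global L2 norm of the gradient tree is at most grad_clip.
  grad_norm = jnp.sqrt(
    sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
  )
  scale = jax.lax.clamp(
    min=0.0, x=grad_clip / (grad_norm + _GRAD_CLIP_EPS), max=1.0
  )
  return loss, jax.tree.map(lambda x: x * scale, grad)


n_dev = jax.local_device_count()
loss, grad = _global_mean_and_clip(
  jnp.ones((n_dev,)) * 8.0,            # per-device summed losses
  jnp.ones((n_dev,)) * 4.0,            # per-device valid-example counts
  {'w': jnp.ones((n_dev, 3)) * 10.0},  # per-device gradient pytree
  1.0,                                 # grad_clip
)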
(summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -139,37 +155,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. 
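# NOTE (illustrative, not from this patch): pmapped outputs keep a leading
# device axis and, after the psum above, carry the same value on every device,
# so the logging below simply reads element [0]. A toy check of that
# convention:
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
per_device = jax.pmap(lambda x: x.mean())(jnp.ones((n_dev, 8)))
assert per_device.shape == (n_dev,)
host_value = per_device[0]             # analogous to loss[0], grad_norm[0]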
if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -209,14 +231,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. 
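# NOTE (illustrative sketch, not from this patch): the optax.adamw transform
# created above follows the usual init / update / apply_updates cycle. Toy
# parameters and hyperparameter values, purely for illustration:
import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.zeros((4, 2)), 'b': jnp.zeros((2,))}
tx = optax.adamw(
  learning_rate=1e-3, b1=0.9, b2=0.999, eps=1e-8, weight_decay=1e-2
)
opt_state = tx.init(params)
grads = jax.tree.map(jnp.ones_like, params)
# Params must be passed so the decoupled weight-decay term can be formed.
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)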
""" diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 21d9b6b57..5df907160 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -2,12 +2,10 @@ from typing import Any, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from absl import logging +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec from algoperf.pytorch_utils import pytorch_setup @@ -15,55 +13,60 @@ USE_PYTORCH_DDP = pytorch_setup()[0] -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates an AdamW optimizer and a learning rate schedule.""" del model_state del rng optimizer_state = { - 'optimizer': - torch.optim.AdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay, - fused=False), + 'optimizer': torch.optim.AdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + fused=False, + ), } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup_steps = hyperparameters.warmup_factor * step_hint warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint, hyperparameters, optimizer_state['optimizer']) + workload.step_hint, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: 
List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -75,22 +78,26 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -104,7 +111,8 @@ def update_params( if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -113,31 +121,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -177,14 +192,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def 
data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/adamw/tuning_search_space.json b/reference_algorithms/paper_baselines/adamw/tuning_search_space.json index c96b03eda..abdd6e32d 100644 --- a/reference_algorithms/paper_baselines/adamw/tuning_search_space.json +++ b/reference_algorithms/paper_baselines/adamw/tuning_search_space.json @@ -1,23 +1,29 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "min": 2e-2, "max": 0.5, "scaling": "log" - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 5e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "min": 2e-2, + "max": 0.5, + "scaling": "log" + }, + "beta2": { + "feasible_points": [0.999] + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 5e-3, + "max": 1.0, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + } } diff --git a/reference_algorithms/paper_baselines/adamw/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/adamw/tuning_search_space_no_beta1.json index b8bd2ea49..5a7c27be7 100644 --- a/reference_algorithms/paper_baselines/adamw/tuning_search_space_no_beta1.json +++ b/reference_algorithms/paper_baselines/adamw/tuning_search_space_no_beta1.json @@ -1,23 +1,27 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "feasible_points": [0.1] - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 5e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "feasible_points": [0.1] + }, + "beta2": { + "feasible_points": [0.999] + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 5e-3, + "max": 1.0, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + } } diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index 70e305514..168c0579b 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ 
b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -3,11 +3,11 @@ import functools from typing import Any, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec @@ -21,11 +21,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True): return optax.scale(m * learning_rate) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a LAMB optimizer and a learning rate schedule.""" del model_params del model_state @@ -35,61 +37,70 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create optimizer + LR schedule. 
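# NOTE (illustrative sketch, not from this patch): the jax_cosine_warmup
# helper used below composes a linear warmup with a cosine decay via
# optax.join_schedules. A standalone version with made-up step counts:
import optax

base_lr, step_hint, warmup_factor = 1e-3, 10_000, 0.05
warmup_steps = int(warmup_factor * step_hint)
warmup = optax.linear_schedule(
  init_value=0.0, end_value=base_lr, transition_steps=warmup_steps
)
cosine = optax.cosine_decay_schedule(
  init_value=base_lr, decay_steps=max(step_hint - warmup_steps, 1)
)
schedule = optax.join_schedules([warmup, cosine], boundaries=[warmup_steps])
# schedule(0) == 0.0, schedule(warmup_steps) == base_lr, then cosine decay.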
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) opt_init_fn, opt_update_fn = optax.lamb( - learning_rate=lr_schedule_fn, - b1=1 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + learning_rate=lr_schedule_fn, + b1=1 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] @@ -97,40 +108,45 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. 
(summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -147,37 +163,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. 
if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -213,14 +235,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. 
""" diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index c1c6cec0a..c73f89e71 100644 --- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -3,25 +3,19 @@ import math from typing import Any, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch +from absl import logging from torch import Tensor -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py class LAMB(torch.optim.Optimizer): - - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0.0): + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0 + ): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -39,7 +33,8 @@ def __setstate__(self, state): super().__setstate__(state) state_values = list(self.state.values()) step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) + state_values[0]['step'] + ) if not step_is_tensor: for s in state_values: s['step'] = torch.tensor(float(s['step'])) @@ -74,48 +69,53 @@ def step(self, closure=None): # State initialization if len(state) == 0: - state['step'] = torch.tensor(0.) + state['step'] = torch.tensor(0.0) # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) state_steps.append(state['step']) lamb( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) return loss -def lamb(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float): - +def lamb( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +): if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) for i, param in enumerate(params): grad = grads[i] @@ -147,61 +147,67 @@ def lamb(params: List[Tensor], update_norm = torch.linalg.norm(update) # Set trust_ratio to 1 in case where parameters would never be updated. - if param_norm == 0. or update_norm == 0.: - trust_ratio = 1. 
+ if param_norm == 0.0 or update_norm == 0.0: + trust_ratio = 1.0 else: trust_ratio = param_norm / update_norm param.add_(update, alpha=-lr * trust_ratio) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a LAMB optimizer and a learning rate schedule.""" del model_state del rng optimizer_state = { - 'optimizer': - LAMB( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay) + 'optimizer': LAMB( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(hyperparameters.beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint, hyperparameters, optimizer_state['optimizer']) + workload.step_hint, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -213,31 +219,36 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if 
hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss, _ = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) loss.backward() if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -246,31 +257,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -306,14 +324,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. 
""" diff --git a/reference_algorithms/paper_baselines/lamb/tuning_search_space.json b/reference_algorithms/paper_baselines/lamb/tuning_search_space.json index f2fcde461..7de33fe47 100644 --- a/reference_algorithms/paper_baselines/lamb/tuning_search_space.json +++ b/reference_algorithms/paper_baselines/lamb/tuning_search_space.json @@ -1,23 +1,29 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "min": 5e-2, "max": 0.3, "scaling": "log" - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 1e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "min": 5e-2, + "max": 0.3, + "scaling": "log" + }, + "beta2": { + "feasible_points": [0.999] + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 1e-3, + "max": 1.0, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + } } diff --git a/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json index 8934e512d..0f1bc208a 100644 --- a/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json +++ b/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json @@ -1,23 +1,27 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "feasible_points": [0.1] - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 1e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "feasible_points": [0.1] + }, + "beta2": { + "feasible_points": [0.999] + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 1e-3, + "max": 1.0, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + } } diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index cbb6d6dcd..df084c17b 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -3,22 +3,24 @@ import functools from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec _GRAD_CLIP_EPS = 1e-6 -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Nesterov optimizer and a 
learning rate schedule.""" del model_params del model_state @@ -28,34 +30,39 @@ def init_optimizer_state(workload: spec.Workload, lr_schedule_fn = create_lr_schedule_fn(workload.step_hint, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) opt_init_fn, opt_update_fn = sgd( - learning_rate=lr_schedule_fn, - weight_decay=hyperparameters.weight_decay, - momentum=1.0 - hyperparameters.one_minus_beta1, - nesterov=False) + learning_rate=lr_schedule_fn, + weight_decay=hyperparameters.weight_decay, + momentum=1.0 - hyperparameters.one_minus_beta1, + nesterov=False, + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn def create_lr_schedule_fn( - step_hint: int, - hyperparameters: spec.Hyperparameters) -> Callable[[int], float]: + step_hint: int, hyperparameters: spec.Hyperparameters +) -> Callable[[int], float]: warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) decay_steps = step_hint - warmup_steps polynomial_schedule_fn = optax.polynomial_schedule( - init_value=hyperparameters.learning_rate, - end_value=hyperparameters.learning_rate * hyperparameters.end_factor, - power=1, - transition_steps=int(decay_steps * hyperparameters.decay_steps_factor)) + init_value=hyperparameters.learning_rate, + end_value=hyperparameters.learning_rate * hyperparameters.end_factor, + power=1, + transition_steps=int(decay_steps * hyperparameters.decay_steps_factor), + ) lr_schedule_fn = optax.join_schedules( - schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps] + ) return lr_schedule_fn @@ -82,81 +89,92 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): {SGD, Momentum, Nesterov} update. 
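# NOTE (illustrative sketch, not from this patch): the sgd() helper below
# chains decoupled weight decay with heavy-ball momentum. An equivalent
# standalone construction, with made-up hyperparameter values:
import optax


def sgd_with_decoupled_wd(learning_rate, weight_decay, momentum):
  return optax.chain(
    optax.add_decayed_weights(weight_decay),
    optax.sgd(learning_rate=learning_rate, momentum=momentum, nesterov=False),
  )


tx = sgd_with_decoupled_wd(learning_rate=1e-2, weight_decay=1e-4, momentum=0.9)
# As with adamw, tx.update must be passed params so the decay term can be formed.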
""" return optax.chain( - optax.add_decayed_weights(weight_decay), - optax.sgd( - learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) + optax.add_decayed_weights(weight_decay), + optax.sgd( + learning_rate=learning_rate, momentum=momentum, nesterov=nesterov + ), + ) @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. 
(summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -173,37 +191,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. 
if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -243,14 +267,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index c3760d20e..b3d38b3dd 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -2,10 +2,10 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple -from absl import logging import optax import torch import torch.distributed.nn as dist_nn +from absl import logging from torch.optim.lr_scheduler import LambdaLR from algoperf import spec @@ -14,24 +14,26 @@ USE_PYTORCH_DDP = pytorch_setup()[0] -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Nesterov optimizer and a learning rate schedule.""" del model_state del rng # Create optimizer. 
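# NOTE (illustrative, not from this patch): the search space tunes
# one_minus_beta1, so the SGD momentum below is 1.0 - one_minus_beta1.
# Made-up values, mirroring the construction that follows:
import torch

one_minus_beta1 = 0.1                   # hypothetical tuned value
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(
  model.parameters(),
  lr=1e-2,
  momentum=1.0 - one_minus_beta1,       # i.e. beta1 = 0.9
  weight_decay=1e-4,
  nesterov=False,
)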
optimizer_state = { - 'optimizer': - torch.optim.SGD( - model_params.parameters(), - lr=hyperparameters.learning_rate, - momentum=1.0 - hyperparameters.one_minus_beta1, - weight_decay=hyperparameters.weight_decay, - nesterov=False), + 'optimizer': torch.optim.SGD( + model_params.parameters(), + lr=hyperparameters.learning_rate, + momentum=1.0 - hyperparameters.one_minus_beta1, + weight_decay=hyperparameters.weight_decay, + nesterov=False, + ), } # Create learning rate schedule. @@ -43,43 +45,48 @@ def _lr_lambda(step: int) -> float: return lr_schedule_fn(step).item() / hyperparameters.learning_rate optimizer_state['scheduler'] = LambdaLR( - optimizer_state['optimizer'], lr_lambda=_lr_lambda) + optimizer_state['optimizer'], lr_lambda=_lr_lambda + ) return optimizer_state def create_lr_schedule_fn( - step_hint: int, - hyperparameters: spec.Hyperparameters) -> Callable[[int], float]: + step_hint: int, hyperparameters: spec.Hyperparameters +) -> Callable[[int], float]: warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) decay_steps = step_hint - warmup_steps polynomial_schedule_fn = optax.polynomial_schedule( - init_value=hyperparameters.learning_rate, - end_value=hyperparameters.learning_rate * hyperparameters.end_factor, - power=1, - transition_steps=int(decay_steps * hyperparameters.decay_steps_factor)) + init_value=hyperparameters.learning_rate, + end_value=hyperparameters.learning_rate * hyperparameters.end_factor, + power=1, + transition_steps=int(decay_steps * hyperparameters.decay_steps_factor), + ) lr_schedule_fn = optax.join_schedules( - schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps] + ) return lr_schedule_fn def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -91,26 +98,30 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - 
hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -123,7 +134,8 @@ def update_params( if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -132,31 +144,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -196,14 +215,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely 
repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/momentum/tuning_search_space.json b/reference_algorithms/paper_baselines/momentum/tuning_search_space.json index 8423bdab7..9ec39a6ef 100644 --- a/reference_algorithms/paper_baselines/momentum/tuning_search_space.json +++ b/reference_algorithms/paper_baselines/momentum/tuning_search_space.json @@ -1,21 +1,27 @@ { "learning_rate": { - "min": 2e-2, "max": 5.0, "scaling": "log" + "min": 2e-2, + "max": 5.0, + "scaling": "log" }, "one_minus_beta1": { - "min": 5e-3, "max": 0.3, "scaling": "log" + "min": 5e-3, + "max": 0.3, + "scaling": "log" }, "warmup_factor": { - "feasible_points": [0.05] + "feasible_points": [0.05] }, "weight_decay": { - "min": 1e-7, "max": 5e-5, "scaling": "log" + "min": 1e-7, + "max": 5e-5, + "scaling": "log" }, "label_smoothing": { - "feasible_points": [0.1, 0.2] + "feasible_points": [0.1, 0.2] }, "dropout_rate": { - "feasible_points": [0.0, 0.1] + "feasible_points": [0.0, 0.1] }, "decay_steps_factor": { "feasible_points": [0.9] diff --git a/reference_algorithms/paper_baselines/momentum/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/momentum/tuning_search_space_no_beta1.json index f874862d8..80f9c7968 100644 --- a/reference_algorithms/paper_baselines/momentum/tuning_search_space_no_beta1.json +++ b/reference_algorithms/paper_baselines/momentum/tuning_search_space_no_beta1.json @@ -1,21 +1,25 @@ { "learning_rate": { - "min": 2e-2, "max": 5.0, "scaling": "log" + "min": 2e-2, + "max": 5.0, + "scaling": "log" }, "one_minus_beta1": { - "feasible_points": [0.1] + "feasible_points": [0.1] }, "warmup_factor": { - "feasible_points": [0.05] + "feasible_points": [0.05] }, "weight_decay": { - "min": 1e-7, "max": 5e-5, "scaling": "log" + "min": 1e-7, + "max": 5e-5, + "scaling": "log" }, "label_smoothing": { - "feasible_points": [0.1, 0.2] + "feasible_points": [0.1, 0.2] }, "dropout_rate": { - "feasible_points": [0.0, 0.1] + "feasible_points": [0.0, 0.1] }, "decay_steps_factor": { "feasible_points": [0.9] diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index c451a18ac..62161b3d5 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -1,26 +1,24 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" import functools - -# isort: off -# We have to turn off isort here to resolve a conflict between isort and yapf. 
-from typing import (Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union) -# isort: on +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) import chex -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec @@ -30,15 +28,14 @@ # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, ) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. @@ -73,19 +70,22 @@ def nadamw( An (init_fn, update_fn) tuple. """ return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate), + ) # All functions below are forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: +def scale_by_nadam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5, +) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. References: @@ -124,7 +124,8 @@ def update_fn(updates, state, params=None): mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat + ) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -132,6 +133,7 @@ def update_fn(updates, state, params=None): class ScaleByAdamState(NamedTuple): """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. 
mu: optax.Updates nu: optax.Updates @@ -140,7 +142,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments + ) def _bias_correction(moment, decay, count): @@ -156,11 +159,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True): return optax.scale(m * learning_rate) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_params del model_state @@ -170,101 +175,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create optimizer + LR schedule. 
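Before the optimizer and schedule are wired up below, here is a scalar sketch of what one `scale_by_nadam` step computes, using the `_update_moment` and `_bias_correction` conventions defined above (debias=True, power=0.5; the numbers are illustrative, not tuned values):

b1, b2, eps, eps_root = 0.9, 0.999, 1e-8, 0.0
g, mu, nu, count = 0.5, 0.0, 0.0, 0  # current gradient and running moments

mu = (1 - b1) * g + b1 * mu        # first moment (order 1)
nu = (1 - b2) * g**2 + b2 * nu     # second moment (order 2)
count += 1
mu_hat = (1 - b1) * g + b1 * mu    # extra Nesterov-style look-ahead on mu
mu_hat = mu_hat / (1 - b1**count)  # bias correction
nu_hat = nu / (1 - b2**count)
update = mu_hat / ((nu_hat + eps_root) ** 0.5 + eps)
# nadamw() then adds decoupled weight decay and scales by -learning_rate.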
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. 
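The `lax.psum` below is what turns per-device sums into global quantities: each device contributes its summed loss, valid-example count, and gradient, and after the all-reduce every device holds the same totals, so dividing gives the global mean. A minimal stand-alone sketch of the pattern:

import jax
import jax.numpy as jnp
from jax import lax

def _global_mean(s, n):
  s, n = lax.psum((s, n), axis_name='batch')
  return s / n

n_devices = jax.local_device_count()
per_device_sums = jnp.arange(1.0, n_devices + 1)  # e.g. [1., 2., ...]
per_device_counts = jnp.full((n_devices,), 4.0)   # 4 valid examples each
global_mean = jax.pmap(_global_mean, axis_name='batch')(
  per_device_sums, per_device_counts
)  # identical value on every device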
(summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -281,37 +300,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. 
if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -351,14 +376,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index a2f9fb4c5..f6c2faa9d 100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -3,13 +3,11 @@ import math from typing import Any, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch -from torch import Tensor import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from absl import logging +from torch import Tensor +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec from algoperf.pytorch_utils import pytorch_setup @@ -21,33 +19,30 @@ class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. 
- - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -59,7 +54,10 @@ def __init__(self, if not 0.0 <= weight_decay: raise ValueError(f'Invalid weight_decay value: {weight_decay}') defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, } super().__init__(params, defaults) @@ -67,7 +65,8 @@ def __setstate__(self, state): super().__setstate__(state) state_values = list(self.state.values()) step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) + state_values[0]['step'] + ) if not step_is_tensor: for s in state_values: s['step'] = torch.tensor(float(s['step'])) @@ -76,9 +75,9 @@ def __setstate__(self, state): def step(self, closure=None): """Performs a single optimization step. - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. """ self._cuda_graph_capture_health_check() @@ -107,51 +106,57 @@ def step(self, closure=None): # State initialization if len(state) == 0: - state['step'] = torch.tensor(0.) 
+ state['step'] = torch.tensor(0.0) # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) state_steps.append(state['step']) nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) return loss -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float) -> None: +def nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +) -> None: r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. + See NAdamW class for details. """ if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) for i, param in enumerate(params): grad = grads[i] @@ -189,54 +194,59 @@ def nadamw(params: List[Tensor], exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_state del rng optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), + 'optimizer': NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ), } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) optimizer_state['scheduler'] = pytorch_cosine_warmup( - 
workload.step_hint, hyperparameters, optimizer_state['optimizer']) + workload.step_hint, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -248,26 +258,30 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -280,7 +294,8 @@ def update_params( if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -289,31 +304,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def 
prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -353,14 +375,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. 
""" diff --git a/reference_algorithms/paper_baselines/nadamw/tuning_search_space.json b/reference_algorithms/paper_baselines/nadamw/tuning_search_space.json index cba20c4c2..a3d322771 100644 --- a/reference_algorithms/paper_baselines/nadamw/tuning_search_space.json +++ b/reference_algorithms/paper_baselines/nadamw/tuning_search_space.json @@ -1,23 +1,29 @@ { "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" + "min": 1e-4, + "max": 1e-2, + "scaling": "log" }, "one_minus_beta1": { - "min": 4e-3, "max": 0.1, "scaling": "log" + "min": 4e-3, + "max": 0.1, + "scaling": "log" }, "beta2": { - "feasible_points": [0.999] + "feasible_points": [0.999] }, "warmup_factor": { - "feasible_points": [0.05] + "feasible_points": [0.05] }, "weight_decay": { - "min": 5e-3, "max": 1.0, "scaling": "log" + "min": 5e-3, + "max": 1.0, + "scaling": "log" }, "label_smoothing": { - "feasible_points": [0.1, 0.2] + "feasible_points": [0.1, 0.2] }, "dropout_rate": { - "feasible_points": [0.0, 0.1] + "feasible_points": [0.0, 0.1] } } diff --git a/reference_algorithms/paper_baselines/nadamw/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/nadamw/tuning_search_space_no_beta1.json index 58973eb27..5a7c27be7 100644 --- a/reference_algorithms/paper_baselines/nadamw/tuning_search_space_no_beta1.json +++ b/reference_algorithms/paper_baselines/nadamw/tuning_search_space_no_beta1.json @@ -1,23 +1,27 @@ { "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" + "min": 1e-4, + "max": 1e-2, + "scaling": "log" }, "one_minus_beta1": { - "feasible_points": [0.1] + "feasible_points": [0.1] }, "beta2": { - "feasible_points": [0.999] + "feasible_points": [0.999] }, "warmup_factor": { - "feasible_points": [0.05] + "feasible_points": [0.05] }, "weight_decay": { - "min": 5e-3, "max": 1.0, "scaling": "log" + "min": 5e-3, + "max": 1.0, + "scaling": "log" }, "label_smoothing": { - "feasible_points": [0.1, 0.2] + "feasible_points": [0.1, 0.2] }, "dropout_rate": { - "feasible_points": [0.0, 0.1] + "feasible_points": [0.0, 0.1] } } diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 0e53aae42..18e58b3c0 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -3,22 +3,24 @@ import functools from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec _GRAD_CLIP_EPS = 1e-6 -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Nesterov optimizer and a learning rate schedule.""" del model_params del model_state @@ -28,34 +30,39 @@ def init_optimizer_state(workload: spec.Workload, lr_schedule_fn = create_lr_schedule_fn(workload.step_hint, hyperparameters) # Create optimizer. 
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) opt_init_fn, opt_update_fn = sgd( - learning_rate=lr_schedule_fn, - weight_decay=hyperparameters.weight_decay, - momentum=1.0 - hyperparameters.one_minus_beta1, - nesterov=True) + learning_rate=lr_schedule_fn, + weight_decay=hyperparameters.weight_decay, + momentum=1.0 - hyperparameters.one_minus_beta1, + nesterov=True, + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn def create_lr_schedule_fn( - step_hint: int, - hyperparameters: spec.Hyperparameters) -> Callable[[int], float]: + step_hint: int, hyperparameters: spec.Hyperparameters +) -> Callable[[int], float]: warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) decay_steps = step_hint - warmup_steps polynomial_schedule_fn = optax.polynomial_schedule( - init_value=hyperparameters.learning_rate, - end_value=hyperparameters.learning_rate * hyperparameters.end_factor, - power=1, - transition_steps=int(decay_steps * hyperparameters.decay_steps_factor)) + init_value=hyperparameters.learning_rate, + end_value=hyperparameters.learning_rate * hyperparameters.end_factor, + power=1, + transition_steps=int(decay_steps * hyperparameters.decay_steps_factor), + ) lr_schedule_fn = optax.join_schedules( - schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps] + ) return lr_schedule_fn @@ -82,81 +89,92 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): {SGD, Momentum, Nesterov} update. 
""" return optax.chain( - optax.add_decayed_weights(weight_decay), - optax.sgd( - learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) + optax.add_decayed_weights(weight_decay), + optax.sgd( + learning_rate=learning_rate, momentum=momentum, nesterov=nesterov + ), + ) @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. 
(summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -173,37 +191,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. 
if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -243,14 +267,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index b4432fbff..9d3bfa6e7 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -2,10 +2,10 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple -from absl import logging import optax import torch import torch.distributed.nn as dist_nn +from absl import logging from torch.optim.lr_scheduler import LambdaLR from algoperf import spec @@ -14,24 +14,26 @@ USE_PYTORCH_DDP = pytorch_setup()[0] -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Nesterov optimizer and a learning rate schedule.""" del model_state del rng # Create optimizer. 
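The PyTorch Nesterov baseline constructs the same `torch.optim.SGD` as the momentum baseline, again with `momentum = 1.0 - one_minus_beta1`, but sets `nesterov=True`, so the parameter step uses PyTorch's look-ahead form (roughly `g + momentum * buf`) instead of the momentum buffer alone. Illustrative values only:

import torch

params = [torch.nn.Parameter(torch.randn(8))]
optimizer = torch.optim.SGD(
  params,
  lr=0.5,            # illustrative, not a tuned value
  momentum=0.9,      # = 1.0 - one_minus_beta1 with one_minus_beta1 = 0.1
  weight_decay=1e-6,
  nesterov=True,     # the momentum baseline uses nesterov=False
)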
optimizer_state = { - 'optimizer': - torch.optim.SGD( - model_params.parameters(), - lr=hyperparameters.learning_rate, - momentum=1.0 - hyperparameters.one_minus_beta1, - weight_decay=hyperparameters.weight_decay, - nesterov=True), + 'optimizer': torch.optim.SGD( + model_params.parameters(), + lr=hyperparameters.learning_rate, + momentum=1.0 - hyperparameters.one_minus_beta1, + weight_decay=hyperparameters.weight_decay, + nesterov=True, + ), } # Create learning rate schedule. @@ -43,43 +45,48 @@ def _lr_lambda(step: int) -> float: return lr_schedule_fn(step).item() / hyperparameters.learning_rate optimizer_state['scheduler'] = LambdaLR( - optimizer_state['optimizer'], lr_lambda=_lr_lambda) + optimizer_state['optimizer'], lr_lambda=_lr_lambda + ) return optimizer_state def create_lr_schedule_fn( - step_hint: int, - hyperparameters: spec.Hyperparameters) -> Callable[[int], float]: + step_hint: int, hyperparameters: spec.Hyperparameters +) -> Callable[[int], float]: warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) decay_steps = step_hint - warmup_steps polynomial_schedule_fn = optax.polynomial_schedule( - init_value=hyperparameters.learning_rate, - end_value=hyperparameters.learning_rate * hyperparameters.end_factor, - power=1, - transition_steps=int(decay_steps * hyperparameters.decay_steps_factor)) + init_value=hyperparameters.learning_rate, + end_value=hyperparameters.learning_rate * hyperparameters.end_factor, + power=1, + transition_steps=int(decay_steps * hyperparameters.decay_steps_factor), + ) lr_schedule_fn = optax.join_schedules( - schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps] + ) return lr_schedule_fn def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -91,26 +98,30 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - 
hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -123,7 +134,8 @@ def update_params( if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() @@ -132,31 +144,38 @@ def update_params( with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -196,14 +215,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely 
repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/nesterov/tuning_search_space.json b/reference_algorithms/paper_baselines/nesterov/tuning_search_space.json index 8423bdab7..9ec39a6ef 100644 --- a/reference_algorithms/paper_baselines/nesterov/tuning_search_space.json +++ b/reference_algorithms/paper_baselines/nesterov/tuning_search_space.json @@ -1,21 +1,27 @@ { "learning_rate": { - "min": 2e-2, "max": 5.0, "scaling": "log" + "min": 2e-2, + "max": 5.0, + "scaling": "log" }, "one_minus_beta1": { - "min": 5e-3, "max": 0.3, "scaling": "log" + "min": 5e-3, + "max": 0.3, + "scaling": "log" }, "warmup_factor": { - "feasible_points": [0.05] + "feasible_points": [0.05] }, "weight_decay": { - "min": 1e-7, "max": 5e-5, "scaling": "log" + "min": 1e-7, + "max": 5e-5, + "scaling": "log" }, "label_smoothing": { - "feasible_points": [0.1, 0.2] + "feasible_points": [0.1, 0.2] }, "dropout_rate": { - "feasible_points": [0.0, 0.1] + "feasible_points": [0.0, 0.1] }, "decay_steps_factor": { "feasible_points": [0.9] diff --git a/reference_algorithms/paper_baselines/nesterov/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/nesterov/tuning_search_space_no_beta1.json index f874862d8..80f9c7968 100644 --- a/reference_algorithms/paper_baselines/nesterov/tuning_search_space_no_beta1.json +++ b/reference_algorithms/paper_baselines/nesterov/tuning_search_space_no_beta1.json @@ -1,21 +1,25 @@ { "learning_rate": { - "min": 2e-2, "max": 5.0, "scaling": "log" + "min": 2e-2, + "max": 5.0, + "scaling": "log" }, "one_minus_beta1": { - "feasible_points": [0.1] + "feasible_points": [0.1] }, "warmup_factor": { - "feasible_points": [0.05] + "feasible_points": [0.05] }, "weight_decay": { - "min": 1e-7, "max": 5e-5, "scaling": "log" + "min": 1e-7, + "max": 5e-5, + "scaling": "log" }, "label_smoothing": { - "feasible_points": [0.1, 0.2] + "feasible_points": [0.1, 0.2] }, "dropout_rate": { - "feasible_points": [0.0, 0.1] + "feasible_points": [0.0, 0.1] }, "decay_steps_factor": { "feasible_points": [0.9] diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index b76589705..9ab193b56 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -3,11 +3,11 @@ import functools from typing import Any, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec @@ -23,7 +23,8 @@ def dual_vector(y: jnp.ndarray) -> jnp.ndarray: y: A pytree of numpy ndarray, vector y in the equation above. 
""" gradient_norm = jnp.sqrt( - sum(jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y))) + sum(jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y)) + ) normalized_gradient = jax.tree.map(lambda x: x / gradient_norm, y) return normalized_gradient @@ -31,11 +32,11 @@ def dual_vector(y: jnp.ndarray) -> jnp.ndarray: # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/ # sharpness_aware_minimization.py def sharpness_aware_minimization( - rho: float, - grad_clip: Optional[float], - batch_axis_name: str, - base_opt_init_fn, - base_opt_update_fn, + rho: float, + grad_clip: Optional[float], + batch_axis_name: str, + base_opt_init_fn, + base_opt_update_fn, ) -> optax.GradientTransformation: """Implementation of Sharpness Aware Minimization (SAM). Paper: https://arxiv.org/abs/2010.01412 @@ -68,22 +69,28 @@ def update_fn(updates, state, grad_fn_params_tuple): # with the same 1e-6 epsilon that is used when clipping the gradients. updates = dual_vector(updates) noised_params = jax.tree_util.tree_map( - lambda p, u: p + rho * u, params, updates) + lambda p, u: p + rho * u, params, updates + ) (_, (n_valid_examples, _)), updates = grad_fn(noised_params) # Get correct global mean grad. - (n_valid_examples, updates) = lax.psum((n_valid_examples, updates), - axis_name=batch_axis_name) + (n_valid_examples, updates) = lax.psum( + (n_valid_examples, updates), axis_name=batch_axis_name + ) updates = jax.tree.map(lambda x: x / n_valid_examples, updates) if grad_clip: updates_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates)) + ) scaled_updates = jax.tree.map( - lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) - updates = jax.lax.cond(updates_norm > grad_clip, - lambda _: scaled_updates, - lambda _: updates, - None) + lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates + ) + updates = jax.lax.cond( + updates_norm > grad_clip, + lambda _: scaled_updates, + lambda _: updates, + None, + ) updates, state = base_opt_update_fn(updates, state, params) return updates, state @@ -91,11 +98,13 @@ def update_fn(updates, state, grad_fn_params_tuple): return optax.GradientTransformation(init_fn, update_fn) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a SAM optimizer (with AdamW base) and a learning rate schedule.""" del model_params del model_state @@ -105,111 +114,127 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. 
warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create base optimizer + LR schedule. lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) opt_init_fn, opt_update_fn = optax.adamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ) # Create SAM update fn. grad_clip = ( - hyperparameters.grad_clip - if hasattr(hyperparameters, 'grad_clip') else None) + hyperparameters.grad_clip if hasattr(hyperparameters, 'grad_clip') else None + ) opt_init_fn, opt_update_fn = sharpness_aware_minimization( - rho=hyperparameters.rho, - grad_clip=grad_clip, - batch_axis_name='batch', - base_opt_init_fn=opt_init_fn, - base_opt_update_fn=opt_update_fn) + rho=hyperparameters.rho, + grad_clip=grad_clip, + batch_axis_name='batch', + base_opt_init_fn=opt_init_fn, + base_opt_update_fn=opt_update_fn, + ) # Initialize optimizer state. - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params, update_batch_norm=True): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=update_batch_norm) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=update_batch_norm, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) 
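`jax_cosine_warmup` above joins a linear warmup with a cosine decay via `optax.join_schedules`. A small sketch with made-up hyperparameter values (peak LR 1e-3, 10% warmup, 1000 steps), purely illustrative:

```python
import optax

step_hint, warmup_factor, peak_lr = 1000, 0.1, 1e-3
warmup_steps = int(warmup_factor * step_hint)
warmup = optax.linear_schedule(init_value=0.0, end_value=peak_lr, transition_steps=warmup_steps)
cosine = optax.cosine_decay_schedule(init_value=peak_lr, decay_steps=max(step_hint - warmup_steps, 1))
schedule = optax.join_schedules(schedules=[warmup, cosine], boundaries=[warmup_steps])

# LR starts at 0, peaks at the end of warmup, and decays to ~0 at step_hint.
print(schedule(0), schedule(warmup_steps), schedule(step_hint))
```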
second_grad_fn = jax.value_and_grad( - functools.partial(_loss_fn, update_batch_norm=False), has_aux=True) + functools.partial(_loss_fn, update_batch_norm=False), has_aux=True + ) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) updates, new_optimizer_state = opt_update_fn( - grad, optimizer_state, (second_grad_fn, current_param_container)) + grad, optimizer_state, (second_grad_fn, current_param_container) + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -226,37 +251,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. 
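`pmapped_train_step` reduces per-device sums with `lax.psum` before dividing by the global example count, so every device ends up with the same mean. A minimal sketch of that reduction pattern (runs on however many local devices are available):

```python
import functools
import jax
import jax.numpy as jnp
from jax import lax

@functools.partial(jax.pmap, axis_name='batch')
def global_mean_loss(per_example_losses):
  summed = jnp.sum(per_example_losses)
  count = jnp.asarray(per_example_losses.shape[0], jnp.float32)
  # Sum the partial results across devices, then normalize once globally.
  summed, count = lax.psum((summed, count), axis_name='batch')
  return summed / count

n_dev = jax.local_device_count()
losses = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)
print(global_mean_loss(losses))  # identical global mean replicated per device
```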
if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -293,14 +324,15 @@ def get_batch_size(workload_name): def data_selection( - workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. """ diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index 92603f036..652ebed1d 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -2,12 +2,10 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple -from absl import logging import torch import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from absl import logging +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from algoperf import spec from algoperf.pytorch_utils import pytorch_setup @@ -17,13 +15,14 @@ # Modified from https://github.com/davda54/sam. 
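The `SAM` class that follows implements the two forward/backward passes from the SAM paper (ascend along the normalized gradient, re-evaluate the loss there, then take the base optimizer step). A hand-rolled sketch of the same idea on a toy model; the repo's class packages this as `first_step`/`second_step`:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(2, 1)
base_opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
rho = 0.05
x, y = torch.randn(8, 2), torch.randn(8, 1)

def loss_fn():
  return torch.nn.functional.mse_loss(model(x), y)

# 1) Gradient at the current parameters.
loss_fn().backward()
grad_norm = torch.norm(torch.stack([p.grad.norm(2) for p in model.parameters()]), 2)
# 2) Ascend to the perturbed parameters w + rho * g / ||g||.
eps = []
with torch.no_grad():
  for p in model.parameters():
    e = rho * p.grad / (grad_norm + 1e-12)
    p.add_(e)
    eps.append(e)
base_opt.zero_grad()
# 3) Gradient at the perturbed point, restore the parameters, then step.
loss_fn().backward()
with torch.no_grad():
  for p, e in zip(model.parameters(), eps):
    p.sub_(e)
base_opt.step()
base_opt.zero_grad()
```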
class SAM(torch.optim.Optimizer): - - def __init__(self, - params: spec.ParameterContainer, - base_optimizer: torch.optim.Optimizer, - rho: float = 0.05, - adaptive: bool = False, - **kwargs): + def __init__( + self, + params: spec.ParameterContainer, + base_optimizer: torch.optim.Optimizer, + rho: float = 0.05, + adaptive: bool = False, + **kwargs, + ): if rho < 0.0: raise ValueError(f'Invalid rho, should be non-negative: {rho}') @@ -79,12 +78,18 @@ def _grad_norm(self): # In case of model parallelism, put everything on the same device. shared_device = self.param_groups[0]['params'][0].device norm = torch.norm( - torch.stack([((torch.abs(p) if group['adaptive'] else 1.0) * - p.grad).norm(p=2).to(shared_device) - for group in self.param_groups - for p in group['params'] - if p.grad is not None]), - p=2) + torch.stack( + [ + ((torch.abs(p) if group['adaptive'] else 1.0) * p.grad) + .norm(p=2) + .to(shared_device) + for group in self.param_groups + for p in group['params'] + if p.grad is not None + ] + ), + p=2, + ) return norm def load_state_dict(self, state_dict: Dict): @@ -92,11 +97,13 @@ def load_state_dict(self, state_dict: Dict): self.base_optimizer.param_groups = self.param_groups -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates an AdamW optimizer and a learning rate schedule.""" del model_state del rng @@ -104,46 +111,50 @@ def init_optimizer_state(workload: spec.Workload, # Create SAM optimizer with AdamW base. base_optimizer = torch.optim.AdamW optimizer_state = { - 'optimizer': - SAM(model_params.parameters(), - base_optimizer=base_optimizer, - rho=hyperparameters.rho, - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), + 'optimizer': SAM( + model_params.parameters(), + base_optimizer=base_optimizer, + rho=hyperparameters.rho, + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay, + ), } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) # Create learning rate schedule. 
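`pytorch_cosine_warmup` above composes a `LinearLR` warmup with `CosineAnnealingLR` through `SequentialLR`. A toy sketch with made-up step counts to show the resulting LR curve:

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.AdamW(params, lr=1e-3)
warmup_steps, total_steps = 10, 100
warmup = LinearLR(opt, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps)
cosine = CosineAnnealingLR(opt, T_max=total_steps - warmup_steps)
sched = SequentialLR(opt, schedulers=[warmup, cosine], milestones=[warmup_steps])

for step in range(total_steps):
  opt.step()
  sched.step()
  if step in (0, warmup_steps - 1, total_steps - 1):
    # LR ramps up to 1e-3 during warmup, then decays toward 0.
    print(step, sched.get_last_lr()[0])
```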
optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint, hyperparameters, optimizer_state['optimizer']) + workload.step_hint, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -156,20 +167,24 @@ def update_params( def _loss_fn(params, update_batch_norm=True): """Loss function used for training.""" logits_batch, new_model_state = workload.model_fn( - params=params, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=update_batch_norm) + params=params, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=update_batch_norm, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -187,7 +202,8 @@ def _loss_fn(params, update_batch_norm=True): with torch.no_grad(): parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) optimizer_state['optimizer'].first_step(zero_grad=True) @@ -198,7 +214,8 @@ def _loss_fn(params, update_batch_norm=True): if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].second_step(zero_grad=True) optimizer_state['scheduler'].step() @@ -206,29 +223,34 @@ def _loss_fn(params, update_batch_norm=True): if global_step <= 100 or global_step % 500 == 0: if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': logging_loss.item(), - 'grad_norm': grad_norm.item(), - }, - global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - 
global_step, - logging_loss.item(), - grad_norm.item()) + { + 'loss': logging_loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + logging_loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -265,14 +287,15 @@ def get_batch_size(workload_name): def data_selection( - workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. 
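The PyTorch update path above logs a global gradient norm (the stacked per-parameter norms) and optionally clips with `clip_grad_norm_`. A standalone sketch of both operations on a toy model:

```python
import torch

model = torch.nn.Linear(4, 2)
loss = model(torch.randn(3, 4)).pow(2).mean()
loss.backward()

parameters = [p for p in model.parameters() if p.grad is not None]
grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
clipped_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(grad_norm.item(), clipped_norm.item())  # same value; grads now have norm <= 1
```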
""" diff --git a/reference_algorithms/paper_baselines/sam/tuning_search_space.json b/reference_algorithms/paper_baselines/sam/tuning_search_space.json index 66dae232b..f32058937 100644 --- a/reference_algorithms/paper_baselines/sam/tuning_search_space.json +++ b/reference_algorithms/paper_baselines/sam/tuning_search_space.json @@ -1,26 +1,32 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "min": 5e-2, "max": 0.43, "scaling": "log" - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 1e-2, "max": 0.2, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - }, - "rho": { - "feasible_points": [0.01, 0.02, 0.05] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "min": 5e-2, + "max": 0.43, + "scaling": "log" + }, + "beta2": { + "feasible_points": [0.999] + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 1e-2, + "max": 0.2, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + }, + "rho": { + "feasible_points": [0.01, 0.02, 0.05] + } } diff --git a/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json index 89c480e7a..ee4e0c3e4 100644 --- a/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json +++ b/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json @@ -1,26 +1,30 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "feasible_points": [0.1] - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 1e-2, "max": 0.2, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - }, - "rho": { - "feasible_points": [0.01, 0.02, 0.05] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "feasible_points": [0.1] + }, + "beta2": { + "feasible_points": [0.999] + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 1e-2, + "max": 0.2, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + }, + "rho": { + "feasible_points": [0.01, 0.02, 0.05] + } } diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index a5c2732ac..c719361d3 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -36,17 +36,17 @@ import functools import itertools import logging -from typing import Any, cast, List, NamedTuple, Optional, TypeVar, Union +from typing import Any, List, NamedTuple, Optional, TypeVar, Union, cast import chex -from flax import struct import jax -from jax import lax -from jax.experimental import pjit -from jax.experimental.sparse import linalg import jax.numpy as jnp import numpy as np import optax +from flax import struct +from jax import lax +from jax.experimental import pjit +from jax.experimental.sparse import 
linalg # Dtype for inverse-pth root routine # Switch to f64 if you have hardware that supports it. Enable the jax flag @@ -61,13 +61,16 @@ @struct.dataclass class QuantizedValue: """State associated with quantized value.""" + quantized: chex.Array diagonal: chex.Array # Diagonal (if extract_diagonal is set) bucket_size: chex.Array quantized_dtype: jnp.dtype = struct.field( - pytree_node=False) # Dtype for the quantized value. + pytree_node=False + ) # Dtype for the quantized value. extract_diagonal: bool = struct.field( - pytree_node=False) # In case its centered. + pytree_node=False + ) # In case its centered. shape: Any = struct.field(pytree_node=False) # Shape of the tensor. @classmethod @@ -75,13 +78,16 @@ def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False): if isinstance(fvalue, list) and not fvalue: return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, []) quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize( - fvalue, quantized_dtype, extract_diagonal) - return QuantizedValue(quantized, - diagonal_fvalue, - bucket_size, - quantized_dtype, - extract_diagonal, - list(quantized.shape)) + fvalue, quantized_dtype, extract_diagonal + ) + return QuantizedValue( + quantized, + diagonal_fvalue, + bucket_size, + quantized_dtype, + extract_diagonal, + list(quantized.shape), + ) # Quantization is from Lingvo JAX optimizers. # We extend it for int16 quantization of PSD matrices. @@ -106,7 +112,8 @@ def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False): if extract_diagonal and fvalue.ndim != 2: raise ValueError( - f'Input array {fvalue} must be 2D to work with extract_diagonal.') + f'Input array {fvalue} must be 2D to work with extract_diagonal.' + ) diagonal_fvalue = [] if extract_diagonal: @@ -119,16 +126,17 @@ def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False): # We first decide the scale. if fvalue.ndim < 1: raise ValueError( - f'Input array {fvalue} must have a strictly positive number of ' - 'dimensions.') + f'Input array {fvalue} must have a strictly positive number of ' + 'dimensions.' + ) max_abs = jnp.max(jnp.abs(fvalue), axis=0) bucket_size = max_abs / num_buckets bs_expanded = bucket_size[jnp.newaxis, Ellipsis] # To avoid divide by 0.0 - bs_nonzero = jnp.where(bs_expanded > 0.0, - bs_expanded, - jnp.ones_like(bs_expanded)) + bs_nonzero = jnp.where( + bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded) + ) ratio = fvalue / bs_nonzero # We use rounding to remove bias. quantized = jnp.round(ratio) @@ -155,10 +163,11 @@ def to_float(self): def _default_zero_field(): return struct.field( - default_factory=functools.partial(jnp.array, 0, jnp.float32)) + default_factory=functools.partial(jnp.array, 0, jnp.float32) + ) -T = TypeVar("T") +T = TypeVar('T') def _maybe_ix(ls, ix): @@ -180,17 +189,19 @@ def wrap_f(x, *args, **kwargs): InversePthRootDiagnosticsSubtype = TypeVar( - "InversePthRootDiagnosticsSubtype", bound="InversePthRootDiagnostics") + 'InversePthRootDiagnosticsSubtype', bound='InversePthRootDiagnostics' +) @struct.dataclass class InversePthRootDiagnostics: """Diagnostics for inverse p-th root iterative procedure. - Given an inverse pth root B = A^(-1/p), contains the average and - maximum diagonal and off diagonal absolute entrywise errors between - (B^p A) and I. - """ + Given an inverse pth root B = A^(-1/p), contains the average and + maximum diagonal and off diagonal absolute entrywise errors between + (B^p A) and I. 
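`QuantizedValue.quantize` above performs per-column linear quantization: each column is scaled by `max_abs / num_buckets` and rounded. A simplified sketch; taking `num_buckets = jnp.iinfo(dtype).max` is an assumption here, since that detail sits in unchanged context outside this hunk:

```python
import jax.numpy as jnp

def quantize(fvalue, quantized_dtype=jnp.int8):
  num_buckets = jnp.iinfo(quantized_dtype).max        # assumed, e.g. 127 for int8
  max_abs = jnp.max(jnp.abs(fvalue), axis=0)
  bucket_size = max_abs / num_buckets                 # per-column scale
  bs = jnp.where(bucket_size > 0.0, bucket_size, jnp.ones_like(bucket_size))
  return jnp.round(fvalue / bs).astype(quantized_dtype), bucket_size

def to_float(quantized, bucket_size):
  return quantized.astype(jnp.float32) * bucket_size

x = jnp.array([[0.5, -1.0], [0.25, 2.0]])
q, bs = quantize(x)
print(jnp.max(jnp.abs(to_float(q, bs) - x)))  # small round-trip error
```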
+ """ + max_diag_error: chex.Array = _default_zero_field() avg_diag_error: chex.Array = _default_zero_field() max_off_diag_error: chex.Array = _default_zero_field() @@ -201,35 +212,41 @@ class InversePthRootDiagnostics: def create(cls, pth_inverse_root, matrix, p): """Generates a diagnostics struct from (-1/p) root result.""" mat_m = jnp.matmul( - mat_power(pth_inverse_root, p), - matrix, - precision=jax.lax.Precision.HIGHEST) + mat_power(pth_inverse_root, p), + matrix, + precision=jax.lax.Precision.HIGHEST, + ) num_off_diag_entries = mat_m.size - jnp.diag(mat_m).size diag_error = jnp.abs(jnp.diag(mat_m) - 1).astype(jnp.float32) off_diag_error = jnp.abs(mat_m - jnp.diag(jnp.diag(mat_m))).astype( - jnp.float32) + jnp.float32 + ) return cls( - max_diag_error=jnp.max(diag_error).astype(jnp.float32), - avg_diag_error=jnp.mean(diag_error).astype(jnp.float32), - max_off_diag_error=jnp.max(off_diag_error).astype(jnp.float32), - avg_off_diag_error=(jnp.sum(off_diag_error) / - num_off_diag_entries).astype(jnp.float32), - p=jnp.array(p, jnp.float32)) + max_diag_error=jnp.max(diag_error).astype(jnp.float32), + avg_diag_error=jnp.mean(diag_error).astype(jnp.float32), + max_off_diag_error=jnp.max(off_diag_error).astype(jnp.float32), + avg_off_diag_error=( + jnp.sum(off_diag_error) / num_off_diag_entries + ).astype(jnp.float32), + p=jnp.array(p, jnp.float32), + ) LOBPCGDiagnosticsSubtype = TypeVar( - "LOBPCGDiagnosticsSubtype", bound="LOBPCGDiagnostics") + 'LOBPCGDiagnosticsSubtype', bound='LOBPCGDiagnostics' +) @struct.dataclass class LOBPCGDiagnostics: """Diagnostics for iterative LOBPCG eigenvalue routine. - Contains consistency error for LOBPCG eigenvalue routine, which - refers to |A v - lambda v| / (lambda + |A v|) for a proposed eigenpair - (v, lambda). This metics dataclass retains consistency error - and other useful LOBPCG values. - """ + Contains consistency error for LOBPCG eigenvalue routine, which + refers to |A v - lambda v| / (lambda + |A v|) for a proposed eigenpair + (v, lambda). This metics dataclass retains consistency error + and other useful LOBPCG values. 
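`InversePthRootDiagnostics.create` above scores a candidate inverse p-th root B by how far B^p A is from the identity, split into diagonal and off-diagonal errors. A numpy stand-in using an exact root computed by eigendecomposition:

```python
import numpy as np

def pth_root_diagnostics(b, a, p):
  mat_m = np.linalg.matrix_power(b, p) @ a
  diag_err = np.abs(np.diag(mat_m) - 1.0)
  off_diag_err = np.abs(mat_m - np.diag(np.diag(mat_m)))
  return diag_err.max(), off_diag_err.max()

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 4))
a = x @ x.T + 4 * np.eye(4)             # symmetric positive definite test matrix
w, v = np.linalg.eigh(a)
b = v @ np.diag(w ** (-1.0 / 4)) @ v.T  # exact A^(-1/4)
print(pth_root_diagnostics(b, a, p=4))  # both errors ~1e-15
```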
+ """ + lobpcg_iters: chex.Array = _default_zero_field() max_consistency_error: chex.Array = _default_zero_field() avg_consistency_error: chex.Array = _default_zero_field() @@ -248,7 +265,8 @@ def create(cls, matrix, eigvals, eigvecs, lobpcg_iters): mat_eigvecs = matrix.dot(eigvecs, precision=precision) consistency_error_unnormalized = jnp.linalg.norm( - mat_eigvecs - eigvals * eigvecs, ord=2, axis=0) + mat_eigvecs - eigvals * eigvecs, ord=2, axis=0 + ) normalization = jnp.linalg.norm(mat_eigvecs, ord=2, axis=0) + eigvals consistency_error = consistency_error_unnormalized / normalization @@ -256,20 +274,22 @@ def create(cls, matrix, eigvals, eigvecs, lobpcg_iters): orthogonality_error -= jnp.diag(jnp.diag(orthogonality_error)) return cls( - lobpcg_iters=jnp.array(lobpcg_iters, jnp.float32), - max_consistency_error=jnp.max(consistency_error).astype(jnp.float32), - avg_consistency_error=jnp.mean(consistency_error).astype(jnp.float32), - avg_orthogonality_error=(jnp.sum(orthogonality_error) / - num_off_diag).astype(jnp.float32), - max_eigenvalue=jnp.max(eigvals).astype(jnp.float32), - min_eigenvalue=jnp.min(eigvals).astype(jnp.float32), - num_topk_eigenvectors=jnp.array(num_topk, jnp.float32), + lobpcg_iters=jnp.array(lobpcg_iters, jnp.float32), + max_consistency_error=jnp.max(consistency_error).astype(jnp.float32), + avg_consistency_error=jnp.mean(consistency_error).astype(jnp.float32), + avg_orthogonality_error=( + jnp.sum(orthogonality_error) / num_off_diag + ).astype(jnp.float32), + max_eigenvalue=jnp.max(eigvals).astype(jnp.float32), + min_eigenvalue=jnp.min(eigvals).astype(jnp.float32), + num_topk_eigenvectors=jnp.array(num_topk, jnp.float32), ) @struct.dataclass class TrainingMetrics: """Diagnostic metrics from training.""" + # Error for inverse-pth roots. inverse_pth_root_errors: chex.Array = _default_zero_field() # Iteration count for inverse-pth roots. @@ -283,20 +303,24 @@ class TrainingMetrics: total_retries: chex.Array = _default_zero_field() lobpcg_diagnostics: LOBPCGDiagnostics = struct.field( - default_factory=LOBPCGDiagnostics) + default_factory=LOBPCGDiagnostics + ) # Rich matrix entrywise error diagnostics, if enabled. inverse_pth_root_diagnostics: InversePthRootDiagnostics = struct.field( - default_factory=InversePthRootDiagnostics) + default_factory=InversePthRootDiagnostics + ) # Diagnostics applied to the conditioned p-th root problem, after top # eigenvectors are removed, if LOBPCG is being applied. conditioned_inverse_pth_root_diagnostics: InversePthRootDiagnostics = ( - struct.field(default_factory=InversePthRootDiagnostics)) + struct.field(default_factory=InversePthRootDiagnostics) + ) # TODO(rohananil): Add more important metrics to track during training. # Per parameter optimizer state used in data-parallel training. 
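`LOBPCGDiagnostics` above tracks the eigenpair consistency error |A v - lambda v| / (lambda + |A v|). A tiny numpy check with an exact eigenpair, for which the error sits at machine precision:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 5))
a = x @ x.T                             # symmetric PSD test matrix
eigvals, eigvecs = np.linalg.eigh(a)
v, lam = eigvecs[:, -1], eigvals[-1]    # top eigenpair

av = a @ v
consistency = np.linalg.norm(av - lam * v) / (lam + np.linalg.norm(av))
print(consistency)                      # ~1e-16 for an exact eigenpair
```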
class ParameterStats(NamedTuple): """State associated to each parameter of the model being trained.""" + diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner statistics: Optional[List[Any]] # Statistics (QuantizedValue, chex.Array) preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array) @@ -321,12 +345,14 @@ class GlobalShardedParameterStats: @struct.dataclass class LocalShardedParameterStats: """State associated to each parameter of the model being trained.""" + diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner momentum: QuantizedValue # Momentum for the shampoo preconditioner training_metrics: Union[TrainingMetrics, optax.MaskedNode] index_start: Union[np.int32, int] = struct.field( - pytree_node=False) # Index into global statistics array + pytree_node=False + ) # Index into global statistics array sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics. @@ -336,39 +362,44 @@ def default_training_metrics(): def init_training_metrics( - num_statistics, - generate_training_metrics, + num_statistics, + generate_training_metrics, ): """Initialize TrainingMetrics, masked if disabled.""" if not generate_training_metrics: return optax.MaskedNode() return jax.tree.map( - functools.partial(jnp.repeat, repeats=num_statistics), - default_training_metrics()) + functools.partial(jnp.repeat, repeats=num_statistics), + default_training_metrics(), + ) def init_training_metrics_shapes( - num_statistics, - generate_training_metrics, + num_statistics, + generate_training_metrics, ): """Initialize training metrics shape/dtype.""" seed = init_training_metrics( - num_statistics, - generate_training_metrics, + num_statistics, + generate_training_metrics, ) return jax.tree.map(lambda arr: [list(arr.shape), arr.dtype], seed) -def init_training_metrics_pspec(generate_training_metrics,): +def init_training_metrics_pspec( + generate_training_metrics, +): """Initialize training metrics partition specification.""" if not generate_training_metrics: return optax.MaskedNode() - return jax.tree.map(lambda _: jax.sharding.PartitionSpec(), - default_training_metrics()) + return jax.tree.map( + lambda _: jax.sharding.PartitionSpec(), default_training_metrics() + ) class ShardedShampooStats(NamedTuple): """Shampoo state in sharded mode.""" + global_stats: Any local_stats: Any @@ -406,35 +437,35 @@ class PreconditionerType(enum.IntEnum): def power_iteration( - matrix, - num_iters=100, - error_tolerance=1e-6, - precision=lax.Precision.HIGHEST, - padding_start=None, + matrix, + num_iters=100, + error_tolerance=1e-6, + precision=lax.Precision.HIGHEST, + padding_start=None, ): r"""Power iteration algorithm. - The power iteration algorithm takes a symmetric PSD matrix `A`, and produces - a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue - of `A`, and a vector v, which is the corresponding eigenvector of `A`. - - References: - [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration) - - Args: - matrix: the symmetric PSD matrix. - num_iters: Number of iterations. - error_tolerance: Iterative exit condition. - precision: precision XLA related flag, the available options are: a) - lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) - lax.Precision.HIGHEST (best possible precision, slowest) - padding_start: if set, assumes rows and columns after padding_start are - zero. 
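The `power_iteration` routine whose docstring is reflowed here estimates the dominant eigenpair of a symmetric PSD matrix by repeated matrix-vector products. A plain-numpy sketch of the same iteration (same seed, simplified stopping rule):

```python
import numpy as np

def power_iteration(matrix, num_iters=100, tol=1e-6, seed=1729):
  v = np.random.RandomState(seed).uniform(-1.0, 1.0, matrix.shape[0])
  s = 0.0
  for _ in range(num_iters):
    v = v / np.linalg.norm(v)
    s_v = matrix @ v          # one matrix-vector product per iteration
    s_new = v @ s_v           # Rayleigh-quotient style eigenvalue estimate
    if abs(s_new - s) <= tol:
      break
    v, s = s_v, s_new
  return v / np.linalg.norm(v), s_new

a = np.diag([1.0, 2.0, 10.0]) + 0.01
vec, val = power_iteration(a)
print(val)  # close to the largest eigenvalue of a (~10.01)
```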
- - Returns: - eigen vector, eigen value - """ + The power iteration algorithm takes a symmetric PSD matrix `A`, and produces + a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue + of `A`, and a vector v, which is the corresponding eigenvector of `A`. + + References: + [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration) + + Args: + matrix: the symmetric PSD matrix. + num_iters: Number of iterations. + error_tolerance: Iterative exit condition. + precision: precision XLA related flag, the available options are: a) + lax.Precision.DEFAULT (better step time, but not precise) b) + lax.Precision.HIGH (increased precision, slower) c) + lax.Precision.HIGHEST (best possible precision, slowest) + padding_start: if set, assumes rows and columns after padding_start are + zero. + + Returns: + eigen vector, eigen value + """ matrix_size = matrix.shape[-1] def _iter_condition(state): @@ -446,32 +477,38 @@ def _iter_body(state): i, new_v, s, s_v, unused_run_step = state new_v = new_v / jnp.linalg.norm(new_v) - s_v = jnp.einsum("ij,j->i", matrix, new_v, precision=precision) - s_new = jnp.einsum("i,i->", new_v, s_v, precision=precision) - return (i + 1, - s_v, - s_new, - s_v, - jnp.greater(jnp.abs(s_new - s), error_tolerance)) + s_v = jnp.einsum('ij,j->i', matrix, new_v, precision=precision) + s_new = jnp.einsum('i,i->', new_v, s_v, precision=precision) + return ( + i + 1, + s_v, + s_new, + s_v, + jnp.greater(jnp.abs(s_new - s), error_tolerance), + ) # Figure out how to use step as seed for random. - v_0 = np.random.RandomState(1729).uniform(-1.0, 1.0, - matrix_size).astype(matrix.dtype) + v_0 = ( + np.random.RandomState(1729) + .uniform(-1.0, 1.0, matrix_size) + .astype(matrix.dtype) + ) v_0 = jnp.array(v_0) if padding_start is not None: - v_0 *= (jnp.arange(len(v_0), dtype=jnp.int32) < padding_start) + v_0 *= jnp.arange(len(v_0), dtype=jnp.int32) < padding_start init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True]) - _, v_out, s_out, _, _ = lax.while_loop(_iter_condition, _iter_body, - init_state) + _, v_out, s_out, _, _ = lax.while_loop( + _iter_condition, _iter_body, init_state + ) v_out = v_out / jnp.linalg.norm(v_out) return v_out, s_out def mat_power( - mat_m, - p, - precision=lax.Precision.HIGHEST, + mat_m, + p, + precision=lax.Precision.HIGHEST, ): """A simple matrix power method. M^p where p can be TracedValue.""" power = jnp.eye(mat_m.shape[0], dtype=_MAT_INV_PTH_ROOT_DTYPE) @@ -483,9 +520,11 @@ def _iter_condition(state): def _iter_body(state): i, power, mat = state - power = jax.lax.cond(i % 2 == 1, - lambda: jnp.matmul(mat, power, precision=precision), - lambda: power) + power = jax.lax.cond( + i % 2 == 1, + lambda: jnp.matmul(mat, power, precision=precision), + lambda: power, + ) i //= 2 mat = jnp.matmul(mat, mat, precision=precision) return i, power, mat @@ -508,78 +547,81 @@ def _stable_subtract(b, a_minus_b): return (b**exp) * jnp.expm1(exp * jnp.log1p(a_minus_b / b)) return jnp.where( - # Choose the branch with the best log1p approximation. - jnp.abs(a_minus_b / b) < jnp.abs(a_minus_b / a), - -_stable_subtract(a, -a_minus_b), - _stable_subtract(b, a_minus_b)) + # Choose the branch with the best log1p approximation. 
+ jnp.abs(a_minus_b / b) < jnp.abs(a_minus_b / a), + -_stable_subtract(a, -a_minus_b), + _stable_subtract(b, a_minus_b), + ) def matrix_inverse_pth_root( - matrix, - p, - num_iters=100, - ridge_epsilon=1e-6, - error_tolerance=1e-6, - precision=lax.Precision.HIGHEST, - relative_matrix_epsilon=True, - lobpcg_topk_precondition=0, - lobpcg_max_iter=0, - padding_start=None, - prev=None, - eigh=False, + matrix, + p, + num_iters=100, + ridge_epsilon=1e-6, + error_tolerance=1e-6, + precision=lax.Precision.HIGHEST, + relative_matrix_epsilon=True, + lobpcg_topk_precondition=0, + lobpcg_max_iter=0, + padding_start=None, + prev=None, + eigh=False, ): """Computes `matrix^(-1/p)`, where `p` is a positive integer. - This function uses the Eigh or Coupled newton iterations algorithm for - the computation of a matrix's inverse pth root. - - - References: - [Functions of Matrices, Theory and Computation, - Nicholas J Higham, Pg 184, Eq 7.18]( - https://epubs.siam.org/doi/book/10.1137/1.9780898717778) - - Args: - matrix: the symmetric PSD matrix whose power it to be computed - p: exponent, for p a positive integer. - num_iters: Maximum number of iterations. - ridge_epsilon: Ridge epsilon added to make the matrix positive definite. - error_tolerance: Error indicator, useful for early termination. - precision: precision XLA related flag, the available options are: a) - lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) - lax.Precision.HIGHEST (best possible precision, slowest) - relative_matrix_epsilon: Whether to use relative epsilon to the max eigen - value when computing inverse-pth root. - lobpcg_topk_precondition: If nonzero, specifies the number of top - eigenvectors to subtract out before performing LOBPCG. Note this makes - relative_matrix_epsilon essentially free. - lobpcg_max_iter: Maximum iteration count for LOBPCG, defaults to - `lobpcg_topk_precondition`. - padding_start: If the input matrix was padded, then zeros out columns and - rows at the padding start. - prev: previous iteration's solution, zero-padded (unused) - eigh: If True, uses eigh for inverse-pth root computation. - - Returns: - `(matrix + eps)^(-1/p)` and error metrics. - - Note `eps` is not added to zeroed out padding rows and - columns. `eps` is just `ridge_epsilon` if - `relative_matrix_epsilon` is set to `False`, otherwise, it is the - ridge epsilon value scaled by the derived maximum eigenvalue of - the input matrix. - """ + This function uses the Eigh or Coupled newton iterations algorithm for + the computation of a matrix's inverse pth root. + + + References: + [Functions of Matrices, Theory and Computation, + Nicholas J Higham, Pg 184, Eq 7.18]( + https://epubs.siam.org/doi/book/10.1137/1.9780898717778) + + Args: + matrix: the symmetric PSD matrix whose power it to be computed + p: exponent, for p a positive integer. + num_iters: Maximum number of iterations. + ridge_epsilon: Ridge epsilon added to make the matrix positive definite. + error_tolerance: Error indicator, useful for early termination. + precision: precision XLA related flag, the available options are: a) + lax.Precision.DEFAULT (better step time, but not precise) b) + lax.Precision.HIGH (increased precision, slower) c) + lax.Precision.HIGHEST (best possible precision, slowest) + relative_matrix_epsilon: Whether to use relative epsilon to the max eigen + value when computing inverse-pth root. 
+ lobpcg_topk_precondition: If nonzero, specifies the number of top + eigenvectors to subtract out before performing LOBPCG. Note this makes + relative_matrix_epsilon essentially free. + lobpcg_max_iter: Maximum iteration count for LOBPCG, defaults to + `lobpcg_topk_precondition`. + padding_start: If the input matrix was padded, then zeros out columns and + rows at the padding start. + prev: previous iteration's solution, zero-padded (unused) + eigh: If True, uses eigh for inverse-pth root computation. + + Returns: + `(matrix + eps)^(-1/p)` and error metrics. + + Note `eps` is not added to zeroed out padding rows and + columns. `eps` is just `ridge_epsilon` if + `relative_matrix_epsilon` is set to `False`, otherwise, it is the + ridge epsilon value scaled by the derived maximum eigenvalue of + the input matrix. + """ if eigh: - return matrix_inverse_pth_root_eigh(matrix, - p, - ridge_epsilon, - error_tolerance, - precision, - relative_matrix_epsilon, - padding_start, - prev) + return matrix_inverse_pth_root_eigh( + matrix, + p, + ridge_epsilon, + error_tolerance, + precision, + relative_matrix_epsilon, + padding_start, + prev, + ) del prev assert matrix.shape[0] == matrix.shape[1] @@ -596,7 +638,8 @@ def matrix_inverse_pth_root( if padding_start is not None: # Zero out padding in identity as well for convergence checks. ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( - matrix.dtype) + matrix.dtype + ) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -607,18 +650,23 @@ def matrix_inverse_pth_root( eigvals, eigvecs, lobpcg_diagnostics = None, None, None if lobpcg_topk_precondition > 0: # TODO(vladf): reuse previous top-k as the initial search directions - pad_shape = (matrix_size - lobpcg_topk_precondition, - lobpcg_topk_precondition) + pad_shape = ( + matrix_size - lobpcg_topk_precondition, + lobpcg_topk_precondition, + ) search_dirs = jnp.concatenate( - (jnp.eye(lobpcg_topk_precondition), jnp.zeros(pad_shape)), axis=0) + (jnp.eye(lobpcg_topk_precondition), jnp.zeros(pad_shape)), axis=0 + ) eigvals, eigvecs, lobpcg_iters = linalg.lobpcg_standard( # pylint: disable=unbalanced-tuple-unpacking - matrix, search_dirs, - lobpcg_topk_precondition if lobpcg_max_iter == 0 else lobpcg_max_iter) + matrix, + search_dirs, + lobpcg_topk_precondition if lobpcg_max_iter == 0 else lobpcg_max_iter, + ) lobpcg_diagnostics = LOBPCGDiagnostics.create( - matrix, - eigvals, - eigvecs, - lobpcg_iters, + matrix, + eigvals, + eigvecs, + lobpcg_iters, ) # The minimal eigenvalue among top-k becomes the maximal one in the whole @@ -628,7 +676,8 @@ def matrix_inverse_pth_root( # Deflate out top eigenvectors to reduce matrix condition number. matrix -= scaled_vecs.dot( - scaled_vecs.T, precision=jax.lax.Precision.HIGHEST) + scaled_vecs.T, precision=jax.lax.Precision.HIGHEST + ) if relative_matrix_epsilon: if eigvals is not None: @@ -637,11 +686,12 @@ def matrix_inverse_pth_root( # Only use power iteration if lobpcg wasn't already used to derive the # top eigenvalue. _, max_ev = power_iteration( - matrix=matrix, - num_iters=100, - error_tolerance=1e-6, - precision=precision, - padding_start=padding_start) + matrix=matrix, + num_iters=100, + error_tolerance=1e-6, + precision=precision, + padding_start=padding_start, + ) else: # Use absolute matrix epsilon scaling otherwise. 
max_ev = 1.0 @@ -654,8 +704,9 @@ def matrix_inverse_pth_root( def _iter_condition(state): i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, error_ratio = state - error_above_threshold = jnp.logical_and(error > error_tolerance, - error_ratio < max_error_ratio) + error_above_threshold = jnp.logical_and( + error > error_tolerance, error_ratio < max_error_ratio + ) return jnp.logical_and(i < num_iters, error_above_threshold) def _iter_body(state): @@ -673,7 +724,6 @@ def _iter_body(state): iters = 0 error_ratio = 0.0 else: - retry_loop_error_threshold = 0.05 num_tries = 6 init_outer_state = tuple([0, identity, 1000.0, 100, 1.0, True]) @@ -691,23 +741,26 @@ def _outer_body_fn(state): new_error = jnp.max(jnp.abs(new_mat_m_0 - identity)) new_mat_h_0 = identity * jnp.power(z, 1.0 / p) init_state = tuple( - [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, 1.0]) + [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, 1.0] + ) iters, mat_m, mat_h, old_mat_h, error, error_ratio = lax.while_loop( - _iter_condition, _iter_body, init_state) + _iter_condition, _iter_body, init_state + ) error = jnp.max(jnp.abs(mat_m - identity)).astype(jnp.float32) is_converged = jnp.asarray(error_ratio < max_error_ratio, old_mat_h.dtype) - resultant_mat_h = is_converged * \ - mat_h + (1 - is_converged) * old_mat_h - return (i + 1, - resultant_mat_h, - error, - iters, - error_ratio, - error > retry_loop_error_threshold) - - loop_outputs = jax.lax.while_loop(_outer_iter_condition_fn, - _outer_body_fn, - init_outer_state) + resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h + return ( + i + 1, + resultant_mat_h, + error, + iters, + error_ratio, + error > retry_loop_error_threshold, + ) + + loop_outputs = jax.lax.while_loop( + _outer_iter_condition_fn, _outer_body_fn, init_outer_state + ) total_retries, resultant_mat_h, error, iters, error_ratio, _ = loop_outputs conditioned_resultant_mat = resultant_mat_h @@ -723,35 +776,39 @@ def _outer_body_fn(state): pth_diff = _pth_root_difference(ridge_epsilon, jnp.min(eigvals), eigvals, p) scaled_vecs = eigvecs * jnp.sqrt(pth_diff) resultant_mat_h = conditioned_resultant_mat - scaled_vecs.dot( - scaled_vecs.T, precision=jax.lax.Precision.HIGHEST) + scaled_vecs.T, precision=jax.lax.Precision.HIGHEST + ) error_metrics = TrainingMetrics( - inverse_pth_root_errors=jnp.array(error, jnp.float32), - inverse_pth_root_iters=jnp.array(iters, jnp.float32), - final_error_ratio=jnp.array(error_ratio, jnp.float32), - max_eigen_value=jnp.array(max_ev, jnp.float32), - total_retries=jnp.array(total_retries, jnp.float32)) + inverse_pth_root_errors=jnp.array(error, jnp.float32), + inverse_pth_root_iters=jnp.array(iters, jnp.float32), + final_error_ratio=jnp.array(error_ratio, jnp.float32), + max_eigen_value=jnp.array(max_ev, jnp.float32), + total_retries=jnp.array(total_retries, jnp.float32), + ) if lobpcg_topk_precondition > 0: - damped_matrix = matrix + \ - (ridge_epsilon * (10**total_retries) * identity) + damped_matrix = matrix + (ridge_epsilon * (10**total_retries) * identity) conditioned_diagnostics = InversePthRootDiagnostics.create( - conditioned_resultant_mat, damped_matrix, p) + conditioned_resultant_mat, damped_matrix, p + ) unconditioned_damped_matrix = original_matrix + ridge_epsilon * identity unconditioned_diagnostics = InversePthRootDiagnostics.create( - resultant_mat_h, unconditioned_damped_matrix, p) + resultant_mat_h, unconditioned_damped_matrix, p + ) # The max entrywise error in error_metrics.inverse_pth_root_errors refers # to what was derived from 
the inverse pth root iteration, which with # LOBPCG refers to the conditioned problem. Make sure to use the error # from the unconditioned problem. unconditional_errors = jnp.maximum( - unconditioned_diagnostics.max_diag_error, - unconditioned_diagnostics.max_off_diag_error) + unconditioned_diagnostics.max_diag_error, + unconditioned_diagnostics.max_off_diag_error, + ) error_metrics = error_metrics.replace( - inverse_pth_root_errors=unconditional_errors, - lobpcg_diagnostics=lobpcg_diagnostics, - conditioned_inverse_pth_root_diagnostics=conditioned_diagnostics, - inverse_pth_root_diagnostics=unconditioned_diagnostics, + inverse_pth_root_errors=unconditional_errors, + lobpcg_diagnostics=lobpcg_diagnostics, + conditioned_inverse_pth_root_diagnostics=conditioned_diagnostics, + inverse_pth_root_diagnostics=unconditioned_diagnostics, ) if padding_start is not None: @@ -759,9 +816,9 @@ def _outer_body_fn(state): # due to some TPU hosts not having the same number of preconditioning # matrices. resultant_mat_h = jnp.where(padding_start == 0, 0.0, resultant_mat_h) - error = jnp.where(padding_start == 0, - 0.0, - error_metrics.inverse_pth_root_errors) + error = jnp.where( + padding_start == 0, 0.0, error_metrics.inverse_pth_root_errors + ) error_metrics = error_metrics.replace(inverse_pth_root_errors=error) resultant_mat_h = jnp.asarray(resultant_mat_h, orig_dtype) @@ -769,44 +826,44 @@ def _outer_body_fn(state): def matrix_inverse_pth_root_eigh( - matrix, - p, - ridge_epsilon=1e-6, - error_tolerance=1e-6, - precision=lax.Precision.HIGHEST, - relative_matrix_epsilon=True, - padding_start=None, - prev=None, + matrix, + p, + ridge_epsilon=1e-6, + error_tolerance=1e-6, + precision=lax.Precision.HIGHEST, + relative_matrix_epsilon=True, + padding_start=None, + prev=None, ): """Computes `matrix^(-1/p)`, where `p` is a positive integer. - This function uses eigh for the computation of a matrix's inverse pth - root. - - Args: - matrix: the symmetric PSD matrix whose power it to be computed - p: exponent, for p a positive integer. - ridge_epsilon: Ridge epsilon added to make the matrix positive definite. - error_tolerance: Error indicator, useful for early termination. - precision: precision XLA related flag, the available options are: a) - lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) - lax.Precision.HIGHEST (best possible precision, slowest) - relative_matrix_epsilon: Whether to use relative epsilon to the max eigen - value when computing inverse-pth root. - padding_start: If the input matrix was padded, then zeros out columns and - rows at the padding start. - prev: previous iteration's solution, zero-padded (unused) - - Returns: - `(matrix + eps)^(-1/p)` and error metrics. - - Note `eps` is not added to zeroed out padding rows and - columns. `eps` is just `ridge_epsilon` if - `relative_matrix_epsilon` is set to `False`, otherwise, it is the - ridge epsilon value scaled by the derived maximum eigenvalue of - the input matrix. - """ + This function uses eigh for the computation of a matrix's inverse pth + root. + + Args: + matrix: the symmetric PSD matrix whose power it to be computed + p: exponent, for p a positive integer. + ridge_epsilon: Ridge epsilon added to make the matrix positive definite. + error_tolerance: Error indicator, useful for early termination. 
+ precision: precision XLA related flag, the available options are: a) + lax.Precision.DEFAULT (better step time, but not precise) b) + lax.Precision.HIGH (increased precision, slower) c) + lax.Precision.HIGHEST (best possible precision, slowest) + relative_matrix_epsilon: Whether to use relative epsilon to the max eigen + value when computing inverse-pth root. + padding_start: If the input matrix was padded, then zeros out columns and + rows at the padding start. + prev: previous iteration's solution, zero-padded (unused) + + Returns: + `(matrix + eps)^(-1/p)` and error metrics. + + Note `eps` is not added to zeroed out padding rows and + columns. `eps` is just `ridge_epsilon` if + `relative_matrix_epsilon` is set to `False`, otherwise, it is the + ridge epsilon value scaled by the derived maximum eigenvalue of + the input matrix. + """ del prev assert matrix.shape[0] == matrix.shape[1] matrix_size = matrix.shape[0] @@ -816,17 +873,19 @@ def matrix_inverse_pth_root_eigh( identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE) if padding_start is not None: ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( - matrix.dtype) + matrix.dtype + ) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix if relative_matrix_epsilon: _, max_ev = power_iteration( - matrix=matrix, - num_iters=100, - error_tolerance=error_tolerance, - precision=precision, - padding_start=padding_start) + matrix=matrix, + num_iters=100, + error_tolerance=error_tolerance, + precision=precision, + padding_start=padding_start, + ) else: # Use absolute matrix epsilon scaling otherwise. max_ev = 1.0 @@ -837,9 +896,9 @@ def matrix_inverse_pth_root_eigh( if padding_start is not None: e *= jnp.flip(ix) mm = functools.partial(jnp.matmul, precision=precision) - inv_e = jnp.where(e == 0.0, - 0.0, - jnp.power(jnp.maximum(e, ridge_epsilon), alpha)) + inv_e = jnp.where( + e == 0.0, 0.0, jnp.power(jnp.maximum(e, ridge_epsilon), alpha) + ) val = mm(mm(u, jnp.diag(inv_e)), u.T) root = u * jnp.sqrt(inv_e) val = mm(root, root.T) @@ -849,12 +908,13 @@ def matrix_inverse_pth_root_eigh( eig_error *= jnp.flip(ix) error = jnp.max(jnp.abs(eig_error)) error_metrics = TrainingMetrics( - inverse_pth_root_errors=jnp.array(error, jnp.float32)) + inverse_pth_root_errors=jnp.array(error, jnp.float32) + ) if padding_start is not None: val = jnp.where(padding_start == 0, 0.0, val) - error = jnp.where(padding_start == 0, - 0.0, - error_metrics.inverse_pth_root_errors) + error = jnp.where( + padding_start == 0, 0.0, error_metrics.inverse_pth_root_errors + ) error_metrics = error_metrics.replace(inverse_pth_root_errors=error) val = jnp.asarray(val, orig_dtype) return val, error_metrics @@ -863,17 +923,17 @@ def matrix_inverse_pth_root_eigh( def merge_small_dims(shape_to_merge, max_dim): """Merge small dimensions. - If there are some small dimensions, we collapse them: - e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024 - [1, 2, 768, 1, 2048] --> [2, 768, 2048] + If there are some small dimensions, we collapse them: + e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024 + [1, 2, 768, 1, 2048] --> [2, 768, 2048] - Args: - shape_to_merge: Shape to merge small dimensions. - max_dim: Maximal dimension of output shape used in merging. + Args: + shape_to_merge: Shape to merge small dimensions. + max_dim: Maximal dimension of output shape used in merging. - Returns: - Merged shape. - """ + Returns: + Merged shape. 
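Most of the `merge_small_dims` body lies outside this hunk. The following greedy implementation is an assumption that reproduces the two docstring examples above; it is not necessarily the exact repo code:

```python
import numpy as np

def merge_small_dims(shape_to_merge, max_dim):
  if shape_to_merge and np.all(np.array(shape_to_merge) == 1):
    return [1]
  merged, product = [], 1
  for d in shape_to_merge:
    if product * d <= max_dim:
      product *= d            # keep folding small dims into the current block
    else:
      if product > 1:
        merged.append(product)
      product = d             # start a new block with the dim that overflowed
  if product > 1:
    merged.append(product)
  return merged

print(merge_small_dims([1, 2, 512, 1, 2048, 1, 3, 4], 1024))  # [1024, 2048, 12]
print(merge_small_dims([1, 2, 768, 1, 2048], 1024))           # [2, 768, 2048]
```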
+ """ if shape_to_merge and np.all(np.array(shape_to_merge) == 1): return [1] @@ -894,20 +954,23 @@ def merge_small_dims(shape_to_merge, max_dim): def pad_square_matrix(mat, max_size): """Pad a square matrix up to max_size. - Args: - mat: a matrix to pad. - max_size: matrix size requested. + Args: + mat: a matrix to pad. + max_size: matrix size requested. - Returns: - Given M returns [[M, 0], [0, I]] - """ + Returns: + Given M returns [[M, 0], [0, I]] + """ rows, cols = mat.shape if rows != cols: - raise ValueError("Must have rows == cols, instead got " - f"rows={rows}, cols={cols}") + raise ValueError( + f'Must have rows == cols, instead got rows={rows}, cols={cols}' + ) if cols > max_size: - raise ValueError("Must have cols <= max_size. Instead got " - f"cols={cols}, max_size={max_size}.") + raise ValueError( + 'Must have cols <= max_size. Instead got ' + f'cols={cols}, max_size={max_size}.' + ) if rows == max_size: return mat pad_size = max_size - rows @@ -923,13 +986,13 @@ def pad_square_matrix(mat, max_size): def pad_vector(vec, max_size): """Pad a vector to a max_size. - Args: - vec: a vector to pad. - max_size: matrix size requested. + Args: + vec: a vector to pad. + max_size: matrix size requested. - Returns: - Given V returns [V, 0] - """ + Returns: + Given V returns [V, 0] + """ size = vec.shape[0] assert size <= max_size if size == max_size: @@ -949,9 +1012,9 @@ def _iter_body(unused_state): def _iter_condition(state): return state[0] - results = jax.lax.while_loop(_iter_condition, - _iter_body, - tuple([predicate] + init_state)) + results = jax.lax.while_loop( + _iter_condition, _iter_body, tuple([predicate] + init_state) + ) return tuple(results[1:]) @@ -985,7 +1048,7 @@ def partition(self, tensor): assert tensor.shape == self._shape tensors = [tensor] - for (i, indices) in self._splits: + for i, indices in self._splits: tensors_local = [] for t in tensors: tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i)) @@ -995,13 +1058,14 @@ def partition(self, tensor): def merge_partitions(self, partitions): """Merge partitions back to original shape.""" - for (i, indices) in reversed(self._splits): + for i, indices in reversed(self._splits): n = len(indices) + 1 partial_merged_tensors = [] ind = 0 while ind < len(partitions): partial_merged_tensors.append( - jnp.concatenate(partitions[ind:ind + n], axis=i)) + jnp.concatenate(partitions[ind : ind + n], axis=i) + ) ind += n partitions = partial_merged_tensors assert len(partitions) == 1 @@ -1011,25 +1075,25 @@ def merge_partitions(self, partitions): def gram_weighted_update(old_stats, g, axis, w1, w2, precision=None): """Updated statistics via weighted average with new Gram matrix. - Returns w₁ R + w₂ Gᵀ G where R is `old_stats` and G is the matrix whose - columns are the flattened slices of the tensor `g` along the given `axis`. - (So, `old_stats` and the returned matrix have dimensions n x n where - n = `g.shape[axis]`). - - Args: - old_stats: Old statistics. - g: Gradient tensor. - axis: Axis along which to slice `g`. - w1: Scalar weight for old statistics. - w2: Scalar weight for new Gram matrix. - precision: Optional precision XLA related flag, the available options are: - a) lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) - lax.Precision.HIGHEST (best possible precision, slowest) - - Returns: - Weighted average of old and new statistics. 
- """ + Returns w₁ R + w₂ Gᵀ G where R is `old_stats` and G is the matrix whose + columns are the flattened slices of the tensor `g` along the given `axis`. + (So, `old_stats` and the returned matrix have dimensions n x n where + n = `g.shape[axis]`). + + Args: + old_stats: Old statistics. + g: Gradient tensor. + axis: Axis along which to slice `g`. + w1: Scalar weight for old statistics. + w2: Scalar weight for new Gram matrix. + precision: Optional precision XLA related flag, the available options are: + a) lax.Precision.DEFAULT (better step time, but not precise) b) + lax.Precision.HIGH (increased precision, slower) c) + lax.Precision.HIGHEST (best possible precision, slowest) + + Returns: + Weighted average of old and new statistics. + """ axes = [i for i in range(g.ndim) if i != axis] gram_matrix = jnp.tensordot(g, g, axes=(axes, axes), precision=precision) return w1 * old_stats + w2 * gram_matrix @@ -1039,67 +1103,68 @@ class Preconditioner: """Compute statistics/shape from gradients for preconditioning.""" def __init__( - self, - param, - block_size, - merge_small_dims_block_size, - best_effort_shape_interpretation, - preconditioner_type=PreconditionerType.ALL, + self, + param, + block_size, + merge_small_dims_block_size, + best_effort_shape_interpretation, + preconditioner_type=PreconditionerType.ALL, ): """Initializes the preconditioner. - Args: - param: parameter to precondition. - block_size: Block size used to split param. - merge_small_dims_block_size: Block size for merging dims. - best_effort_shape_interpretation: Whether to - collapse/merge dims together. - preconditioner_type: Type of preconditioner to use. - """ + Args: + param: parameter to precondition. + block_size: Block size used to split param. + merge_small_dims_block_size: Block size for merging dims. + best_effort_shape_interpretation: Whether to + collapse/merge dims together. + preconditioner_type: Type of preconditioner to use. + """ self._original_shape = param.shape self._transformed_shape = param.shape if best_effort_shape_interpretation: - self._transformed_shape = merge_small_dims(self._original_shape, - merge_small_dims_block_size) + self._transformed_shape = merge_small_dims( + self._original_shape, merge_small_dims_block_size + ) reshaped_param = jnp.reshape(param, self._transformed_shape) self._partitioner = BlockPartitioner(reshaped_param, block_size) self._preconditioner_type = preconditioner_type def updated_statistics_from_grad( - self, - stats, - grad, - w1, - w2, - to_float=None, - from_float=None, - precision=None, + self, + stats, + grad, + w1, + w2, + to_float=None, + from_float=None, + precision=None, ): """Update statistics from gradients. - Args: - stats: Old statistics or its Cholesky factor if `cholesky` is True. - grad: Gradient to compute statistics from. - w1: Weight for old statistics. - w2: Weight for new statistics. - to_float: Optional function for converting stats to floating point. - from_float: Optional function for converting from floating point. - precision: Optional precision XLA related flag, the available options - are: - a) lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) - lax.Precision.HIGHEST (best possible precision, slowest) - - Returns: - A list of updated gradient statistics for each partition. - """ + Args: + stats: Old statistics or its Cholesky factor if `cholesky` is True. + grad: Gradient to compute statistics from. + w1: Weight for old statistics. + w2: Weight for new statistics. 
+ to_float: Optional function for converting stats to floating point. + from_float: Optional function for converting from floating point. + precision: Optional precision XLA related flag, the available options + are: + a) lax.Precision.DEFAULT (better step time, but not precise) b) + lax.Precision.HIGH (increased precision, slower) c) + lax.Precision.HIGHEST (best possible precision, slowest) + + Returns: + A list of updated gradient statistics for each partition. + """ to_float = to_float if to_float is not None else (lambda x: x) from_float = from_float if from_float is not None else (lambda x: x) reshaped_grad = jnp.reshape(grad, self._transformed_shape) partitioned_grads = self._partitioner.partition(reshaped_grad) should_preconditioned_dims = self.should_precondition_dims() preconditioned_dims = [ - i for i, p in enumerate(should_preconditioned_dims) if p + i for i, p in enumerate(should_preconditioned_dims) if p ] new_stats = [] index = 0 @@ -1136,8 +1201,7 @@ def _preconds_for_grad(self, preconditioners, rank, start, end): elif self._preconditioner_type == PreconditionerType.OUTPUT: # When _preconditioner_type is OUTPUT, we append (rank - 1) many None # values to the beginning of the list to handle the False indices. - preconditioners_for_grad = [None] * \ - (rank - 1) + preconditioners_for_grad + preconditioners_for_grad = [None] * (rank - 1) + preconditioners_for_grad assert len(preconditioners_for_grad) == rank return preconditioners_for_grad @@ -1165,13 +1229,13 @@ def exponent_for_preconditioner(self): def preconditioned_grad(self, grad, preconditioners): """Precondition the gradient. - Args: - grad: A gradient tensor to precondition. - preconditioners: A list of preconditioners to apply. + Args: + grad: A gradient tensor to precondition. + preconditioners: A list of preconditioners to apply. - Returns: - A preconditioned gradient. - """ + Returns: + A preconditioned gradient. 
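# Illustrative aside (not part of the diff): a simplified sketch of applying
# one preconditioner per tensor axis, as the method above does for each
# partitioned block. It ignores block partitioning, skipped dimensions, and
# the INPUT/OUTPUT-only preconditioner types; names and shapes are
# hypothetical.
import jax.numpy as jnp

def precondition_block_sketch(g, preconditioners):
  for axis, p in enumerate(preconditioners):
    # Contract this axis of g with its preconditioner, then move the
    # resulting axis back into place.
    g = jnp.moveaxis(jnp.tensordot(g, p, axes=[[axis], [0]]), -1, axis)
  return g

g = jnp.ones((4, 3))
identity_preconds = [jnp.eye(4), jnp.eye(3)]  # identities leave g unchanged
out = precondition_block_sketch(g, identity_preconds)
assert out.shape == (4, 3)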
+ """ reshaped_grad = jnp.reshape(grad, self._transformed_shape) partitioned_grads = self._partitioner.partition(reshaped_grad) should_preconditioned_dims = self.should_precondition_dims() @@ -1179,17 +1243,18 @@ def preconditioned_grad(self, grad, preconditioners): preconditioned_partitioned_grads = [] for i, g in enumerate(partitioned_grads): preconditioners_for_grad = self._preconds_for_grad( - preconditioners, - rank=len(should_preconditioned_dims), - start=i * num_preconditioners, - end=(i + 1) * num_preconditioners, + preconditioners, + rank=len(should_preconditioned_dims), + start=i * num_preconditioners, + end=(i + 1) * num_preconditioners, + ) + precond_g = self._precondition_block( + g, should_preconditioned_dims, preconditioners_for_grad ) - precond_g = self._precondition_block(g, - should_preconditioned_dims, - preconditioners_for_grad) preconditioned_partitioned_grads.append(precond_g) merged_grad = self._partitioner.merge_partitions( - preconditioned_partitioned_grads) + preconditioned_partitioned_grads + ) return jnp.reshape(merged_grad, self._original_shape) def _precondition_block(self, g, should_precondition_dim, preconditioners): @@ -1208,9 +1273,9 @@ def _precondition_block(self, g, should_precondition_dim, preconditioners): return g -def _convert_to_parameter_stats(global_stats, - local_stat, - convert_statistics=True): +def _convert_to_parameter_stats( + global_stats, local_stat, convert_statistics=True +): """Creates parameter stats from sharded stats.""" index_start = int(local_stat.index_start) index_end = int(len(local_stat.sizes)) + index_start @@ -1225,24 +1290,24 @@ def _convert_to_parameter_stats(global_stats, if not convert_statistics: new_statistics = None return ParameterStats( - local_stat.diagonal_statistics, - new_statistics, - new_preconditioners, - local_stat.diagonal_momentum, - local_stat.momentum, - local_stat.training_metrics, + local_stat.diagonal_statistics, + new_statistics, + new_preconditioners, + local_stat.diagonal_momentum, + local_stat.momentum, + local_stat.training_metrics, ) def _convert_from_parameter_stats(parameter_stats, local_stats): """Creates sharded stats from paramter stats.""" return LocalShardedParameterStats( - parameter_stats.diagonal_statistics, - parameter_stats.diagonal_momentum, - parameter_stats.momentum, - parameter_stats.training_metrics, - local_stats.index_start, - local_stats.sizes, + parameter_stats.diagonal_statistics, + parameter_stats.diagonal_momentum, + parameter_stats.momentum, + parameter_stats.training_metrics, + local_stats.index_start, + local_stats.sizes, ) @@ -1258,12 +1323,13 @@ def _add_metrics_into_local_stats(local_stats, metrics, keep_old): # root calculation to find a new preconditioner, so that TensorBoard curves # look consistent (otherwise they'd oscillate between NaN and measured # values). 
- per_stat_metrics = efficient_cond(keep_old, - lambda: [local_stat.training_metrics], - [per_stat_metrics])[0] + per_stat_metrics = efficient_cond( + keep_old, lambda: [local_stat.training_metrics], [per_stat_metrics] + )[0] # pylint:enable=cell-var-from-loop new_local_stats.append( - local_stat.replace(training_metrics=per_stat_metrics)) + local_stat.replace(training_metrics=per_stat_metrics) + ) return new_local_stats @@ -1271,7 +1337,7 @@ def batch(x, num_devices): """Batch `x` so that so that leading axis is num_devices.""" n = len(x) b = int(n / num_devices) - return jnp.stack([jnp.stack(x[idx:idx + b]) for idx in range(0, n, b)]) + return jnp.stack([jnp.stack(x[idx : idx + b]) for idx in range(0, n, b)]) def unbatch(batched_values): @@ -1290,162 +1356,168 @@ def unbatch(batched_values): def distributed_shampoo( - learning_rate, - block_size=1024, - beta1=0.9, - beta2=0.999, - diagonal_epsilon=1e-8, - matrix_epsilon=1e-6, - weight_decay=0.0, - start_preconditioning_step=101, - preconditioning_compute_steps=20, - statistics_compute_steps=1, - best_effort_shape_interpretation=True, - graft_type=GraftingType.RMSPROP_NORMALIZED, - nesterov=True, - exponent_override=0, - # Pass pmap 'batch axis name' in pmap mode. - batch_axis_name=None, - # Only set following 3 params in pjit/spmd mode. - # WARNING: Experimental - statistics_partition_spec=None, - preconditioner_partition_spec=None, - num_devices_for_pjit=None, - shard_optimizer_states=False, - ### - # Experimental memory reduction mode - best_effort_memory_usage_reduction=True, - ### - inverse_failure_threshold=0.1, - moving_average_for_momentum=True, - skip_preconditioning_dim_size_gt=0, - clip_by_scaled_gradient_norm=None, - precision=lax.Precision.HIGHEST, - tensordot_precision=None, - relative_matrix_epsilon=True, - merge_small_dims_block_size=4096, - lobpcg_topk_precondition=0, - lobpcg_max_iter=0, - precondtioner_type=PreconditionerType.ALL, - custom_preconditioner=False, - skip_preconditioning_rank_lt=1, - decoupled_learning_rate=True, - decoupled_weight_decay=False, - generate_training_metrics=True, - reuse_preconditioner=False, - eigh=True, + learning_rate, + block_size=1024, + beta1=0.9, + beta2=0.999, + diagonal_epsilon=1e-8, + matrix_epsilon=1e-6, + weight_decay=0.0, + start_preconditioning_step=101, + preconditioning_compute_steps=20, + statistics_compute_steps=1, + best_effort_shape_interpretation=True, + graft_type=GraftingType.RMSPROP_NORMALIZED, + nesterov=True, + exponent_override=0, + # Pass pmap 'batch axis name' in pmap mode. + batch_axis_name=None, + # Only set following 3 params in pjit/spmd mode. + # WARNING: Experimental + statistics_partition_spec=None, + preconditioner_partition_spec=None, + num_devices_for_pjit=None, + shard_optimizer_states=False, + ### + # Experimental memory reduction mode + best_effort_memory_usage_reduction=True, + ### + inverse_failure_threshold=0.1, + moving_average_for_momentum=True, + skip_preconditioning_dim_size_gt=0, + clip_by_scaled_gradient_norm=None, + precision=lax.Precision.HIGHEST, + tensordot_precision=None, + relative_matrix_epsilon=True, + merge_small_dims_block_size=4096, + lobpcg_topk_precondition=0, + lobpcg_max_iter=0, + precondtioner_type=PreconditionerType.ALL, + custom_preconditioner=False, + skip_preconditioning_rank_lt=1, + decoupled_learning_rate=True, + decoupled_weight_decay=False, + generate_training_metrics=True, + reuse_preconditioner=False, + eigh=True, ): """Distributed Shampoo optimizer. 
- Distributed Shampoo is a second-order preconditioned method (concretely, a - variant of full-matrix Adagrad), that provides significant convergence and - wall-clock time improvements compared to conventional first-order methods, - and that has been shown to scale to large state-of-the-art deep learning - models. - - References: - Scalable Second Order Optimization for Deep Learning, - Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer - - Preprint: https://arxiv.org/abs/2002.09018 - - Args: - learning_rate: the step size used to update the parameters. - block_size: Block size for large layers (if > 0). Preconditioning compute - operation is cubic in the dimension of the tensor. Block size allows us - to chunk the layers into sub-layers of maximal dimension dictated by - this value. Use 128 as default (increase if you have compute budget). - beta1: momentum parameter. - beta2: second moment averaging parameter. - diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting - to AdaGrad is enabled). - matrix_epsilon: epsilon to add to statistics before computing inverse pth - root. If you are running in f32 precision for inverse pth root - (recommended today) this can go upto 1e-6. If you have latest hardware - with native f64 precision, set this upto 1e-12. - weight_decay: Weight decay for regularization. - start_preconditioning_step: When to start Shampoo update before which - diagonal update is used. This is because we dont have enough information - to do stable inverse. - preconditioning_compute_steps: How often to compute preconditioner. - Performance tuning params for controlling memory and compute - requirements. - Ideally set this and statistics_compute_steps params to 1. - statistics_compute_steps: How often to compute statistics. - best_effort_shape_interpretation: If there are some small dimensions, - collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if - block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048] - graft_type: Grafting is a technique to fix the layerwise scale of Shampoo - optimizer. This allows us to plugin the Shampoo optimizer into settings - where SGD/AdaGrad is already well tuned. - nesterov: Nesterov momentum. - exponent_override: Override the exponent used in matrix inverse. - batch_axis_name: labeled axis over pmap for data-parallel training the - optimizer used for. - statistics_partition_spec: PartitionSpec to be used in sharded mode. - preconditioner_partition_spec: PartitionSpec to be used in sharded mode. - num_devices_for_pjit: Number of devices to parallelize over when using - pjit. - shard_optimizer_states: Shard optimizer states to save memory in model - parallel training. - best_effort_memory_usage_reduction: Best effort memory usage reduction. - - diagonal_statistics -> jnp.bfloat16 - momentum buffers (2x) - -> jnp.int8 - statistics, preconditioners -> jnp.int16 + diagonals - inverse_failure_threshold: numerics are hard and inverses fail sometimes; - we determine that using this threshold. - moving_average_for_momentum: Whether to use moving average for momentum - instead of exponential moving average. - skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is - greater than this value. - clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful - when using RMSProp Grafting). 
- precision: precision XLA related flag, the available options are: a) - lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) - lax.Precision.HIGHEST (best possible precision, slowest) - tensordot_precision: Optional precision to use for the tensordot operation - when computing statistics (e.g., G Gᵀ). Same options as `precision` - above. - relative_matrix_epsilon: Whether to use relative epsilon to the max eigen - value when computing inverse-pth root. - merge_small_dims_block_size: Used as the maximum block size to merge the - shapes. - lobpcg_topk_precondition: If nonzero, specifies the number of top - eigenvectors to subtract out before performing LOBPCG. Note this makes - relative_matrix_epsilon essentially free. - lobpcg_max_iter: Number of LOBPCG iterations, if zero defaults to - `lobpcg_topk_precondition`. - precondtioner_type: Preconditioner type to select all, left only or right - only preconditioners. - skip_preconditioning_rank_lt: Skips preconditioning for parameters with - rank less than this value. - decoupled_learning_rate: If True, use decoupled learning rate, otherwise - couple it with preconditioned gradient computation. (Default True) - decoupled_weight_decay: If True, use decoupled weight decay, otherwise - couple with weight decay. (Default False) - generate_training_metrics: If True, gather training metrics, otherwise - avoid generating them (to reduce memory usage). - reuse_preconditioner: If True, pass the previous derived preconditioner - as a warm start to the next iteratin's inverse pth root computation. - eigh: If True, and uses eigen decomposition for inverse-pth root. - - Returns: - a GradientTransformation. - """ + Distributed Shampoo is a second-order preconditioned method (concretely, a + variant of full-matrix Adagrad), that provides significant convergence and + wall-clock time improvements compared to conventional first-order methods, + and that has been shown to scale to large state-of-the-art deep learning + models. + + References: + Scalable Second Order Optimization for Deep Learning, + Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer + + Preprint: https://arxiv.org/abs/2002.09018 + + Args: + learning_rate: the step size used to update the parameters. + block_size: Block size for large layers (if > 0). Preconditioning compute + operation is cubic in the dimension of the tensor. Block size allows us + to chunk the layers into sub-layers of maximal dimension dictated by + this value. Use 128 as default (increase if you have compute budget). + beta1: momentum parameter. + beta2: second moment averaging parameter. + diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting + to AdaGrad is enabled). + matrix_epsilon: epsilon to add to statistics before computing inverse pth + root. If you are running in f32 precision for inverse pth root + (recommended today) this can go upto 1e-6. If you have latest hardware + with native f64 precision, set this upto 1e-12. + weight_decay: Weight decay for regularization. + start_preconditioning_step: When to start Shampoo update before which + diagonal update is used. This is because we dont have enough information + to do stable inverse. + preconditioning_compute_steps: How often to compute preconditioner. + Performance tuning params for controlling memory and compute + requirements. + Ideally set this and statistics_compute_steps params to 1. + statistics_compute_steps: How often to compute statistics. 
+ best_effort_shape_interpretation: If there are some small dimensions, + collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if + block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048] + graft_type: Grafting is a technique to fix the layerwise scale of Shampoo + optimizer. This allows us to plugin the Shampoo optimizer into settings + where SGD/AdaGrad is already well tuned. + nesterov: Nesterov momentum. + exponent_override: Override the exponent used in matrix inverse. + batch_axis_name: labeled axis over pmap for data-parallel training the + optimizer used for. + statistics_partition_spec: PartitionSpec to be used in sharded mode. + preconditioner_partition_spec: PartitionSpec to be used in sharded mode. + num_devices_for_pjit: Number of devices to parallelize over when using + pjit. + shard_optimizer_states: Shard optimizer states to save memory in model + parallel training. + best_effort_memory_usage_reduction: Best effort memory usage reduction. - + diagonal_statistics -> jnp.bfloat16 - momentum buffers (2x) + -> jnp.int8 - statistics, preconditioners -> jnp.int16 + diagonals + inverse_failure_threshold: numerics are hard and inverses fail sometimes; + we determine that using this threshold. + moving_average_for_momentum: Whether to use moving average for momentum + instead of exponential moving average. + skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is + greater than this value. + clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful + when using RMSProp Grafting). + precision: precision XLA related flag, the available options are: a) + lax.Precision.DEFAULT (better step time, but not precise) b) + lax.Precision.HIGH (increased precision, slower) c) + lax.Precision.HIGHEST (best possible precision, slowest) + tensordot_precision: Optional precision to use for the tensordot operation + when computing statistics (e.g., G Gᵀ). Same options as `precision` + above. + relative_matrix_epsilon: Whether to use relative epsilon to the max eigen + value when computing inverse-pth root. + merge_small_dims_block_size: Used as the maximum block size to merge the + shapes. + lobpcg_topk_precondition: If nonzero, specifies the number of top + eigenvectors to subtract out before performing LOBPCG. Note this makes + relative_matrix_epsilon essentially free. + lobpcg_max_iter: Number of LOBPCG iterations, if zero defaults to + `lobpcg_topk_precondition`. + precondtioner_type: Preconditioner type to select all, left only or right + only preconditioners. + skip_preconditioning_rank_lt: Skips preconditioning for parameters with + rank less than this value. + decoupled_learning_rate: If True, use decoupled learning rate, otherwise + couple it with preconditioned gradient computation. (Default True) + decoupled_weight_decay: If True, use decoupled weight decay, otherwise + couple with weight decay. (Default False) + generate_training_metrics: If True, gather training metrics, otherwise + avoid generating them (to reduce memory usage). + reuse_preconditioner: If True, pass the previous derived preconditioner + as a warm start to the next iteratin's inverse pth root computation. + eigh: If True, and uses eigen decomposition for inverse-pth root. + + Returns: + a GradientTransformation. 
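# Illustrative aside (not part of the diff): since the factory above returns
# an optax-style GradientTransformation, a single-host loop might wire it up
# roughly as follows. All hyperparameter values are placeholders, and the
# example assumes `distributed_shampoo` is importable from this module.
import jax
import jax.numpy as jnp
import optax

tx = distributed_shampoo(
    learning_rate=1e-3,
    block_size=1024,
    start_preconditioning_step=101,
    preconditioning_compute_steps=20,
    batch_axis_name=None,  # non-pmap sketch; pass the pmap axis name otherwise
)
params = {'w': jnp.ones((256, 128))}
opt_state = tx.init(params)
grads = jax.tree.map(jnp.ones_like, params)  # stand-in gradients
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)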
+ """ reset_frequency = None def _graft_type_has_diagonal_statistics(): """Returns True if using diagonal firt order method for grafting.""" return graft_type not in [ - GraftingType.SGD, GraftingType.SQRT_N, GraftingType.NONE + GraftingType.SGD, + GraftingType.SQRT_N, + GraftingType.NONE, ] def quantized_dtype_for_momentum_buffers(var): - return jnp.int8 if best_effort_memory_usage_reduction and len( - var.shape) > 1 else jnp.float32 + return ( + jnp.int8 + if best_effort_memory_usage_reduction and len(var.shape) > 1 + else jnp.float32 + ) quantize_second_moment = ( - best_effort_memory_usage_reduction and batch_axis_name) + best_effort_memory_usage_reduction and batch_axis_name + ) # Preconditioner and statistics are both stores as int16 in this mode. # We take out the diagonal to make quantization easier. @@ -1472,19 +1544,20 @@ def _to_float(maybe_quantized): def preconditioner_from_params(param): """Returns a Preconditioner object for given param.""" return Preconditioner( - param, - block_size, - merge_small_dims_block_size, - best_effort_shape_interpretation, - precondtioner_type, + param, + block_size, + merge_small_dims_block_size, + best_effort_shape_interpretation, + precondtioner_type, ) def precond_dim(max_size): """Derives largest preconditioner dimension.""" return max_size - def pad_and_maybe_zero_preconditioners(preconditioners, total, max_size, - step): + def pad_and_maybe_zero_preconditioners( + preconditioners, total, max_size, step + ): """Pad preconditioners up to total x max_size x precond_dim(max_size).""" pd = precond_dim(max_size) @@ -1513,9 +1586,9 @@ def _pad_preconditioner(preconditioner): def sharded_init_fn(params): """Returns optimizer state (for PJIT mode). - Args: - params: the parameters that should be updated. - """ + Args: + params: the parameters that should be updated. + """ params_flat, treedef = jax.tree_util.tree_flatten(params) # Find max size to pad to. max_size = 0 @@ -1542,21 +1615,22 @@ def sharded_init_fn(params): sizes = [s[0] for s in shapes] shapes = preconditioner.shapes_for_preconditioners() statistics = [ - matrix_epsilon * jnp.eye(max_size, dtype=jnp.float32) - for s in shapes + matrix_epsilon * jnp.eye(max_size, dtype=jnp.float32) for s in shapes ] pd = precond_dim(max_size) # If the preconditioner is using a low-rank representation, initialize # it to zero instead of an invalid eye. 
preconditioners = [ - jnp.eye(max_size, pd, dtype=jnp.float32) * (pd == max_size) - for s in shapes + jnp.eye(max_size, pd, dtype=jnp.float32) * (pd == max_size) + for s in shapes ] padded_statistics.extend(statistics) padded_preconditioners.extend(preconditioners) exponent = ( - preconditioner.exponent_for_preconditioner() - if exponent_override == 0 else exponent_override) + preconditioner.exponent_for_preconditioner() + if exponent_override == 0 + else exponent_override + ) exponents.extend([exponent] * len(shapes)) diagonal_statistics = jnp.zeros_like(param) @@ -1564,16 +1638,18 @@ def sharded_init_fn(params): momentum = jnp.zeros_like(param) local_stats_flat.append( - LocalShardedParameterStats( - diagonal_statistics, - diagonal_momentum, - momentum, - init_training_metrics( - len(sizes), - generate_training_metrics, - ), - index_start, - sizes)) + LocalShardedParameterStats( + diagonal_statistics, + diagonal_momentum, + momentum, + init_training_metrics( + len(sizes), + generate_training_metrics, + ), + index_start, + sizes, + ) + ) local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat) to_pad = -len(padded_statistics) % num_devices_for_pjit @@ -1588,22 +1664,27 @@ def sharded_init_fn(params): # TODO(rohananil): Relax to only the size of the mesh axis where the dim # is split on. padded_statistics.extend( - [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]) + [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)] + ) pd = precond_dim(max_size) # If the preconditioner is using a low-rank representation, initialize # it to zero instead of an invalid eye. - padded_preconditioners.extend([ + padded_preconditioners.extend( + [ jnp.eye(max_size, pd, dtype=stat_dtype) * (pd == max_size) for _ in range(to_pad) - ]) + ] + ) exponents.extend([1 for _ in range(to_pad)]) global_stats = GlobalShardedParameterStats( - jnp.stack(padded_statistics), - jnp.stack(padded_preconditioners), - jnp.stack(exponents)) + jnp.stack(padded_statistics), + jnp.stack(padded_preconditioners), + jnp.stack(exponents), + ) return ShampooState( - count=jnp.zeros([], jnp.int32), - stats=ShardedShampooStats(global_stats, local_stats)) + count=jnp.zeros([], jnp.int32), + stats=ShardedShampooStats(global_stats, local_stats), + ) def _max_statistics_size_from_params(params): max_size = 0 @@ -1624,20 +1705,21 @@ def _remove_leading_sharding_annotation(pspec): else: return [] - def sharded_init_partition_spec_fn(params, - params_partition_spec, - partition_spec_for_statistics): + def sharded_init_partition_spec_fn( + params, params_partition_spec, partition_spec_for_statistics + ): """Returns a parallel state tree with PartitionSpec associated with state. - Args: - params: A pytree with params. - params_partition_spec: A pytree with PartitionSpec for params. - partition_spec_for_statistics: PartitionSpec for the statistics. - """ + Args: + params: A pytree with params. + params_partition_spec: A pytree with PartitionSpec for params. + partition_spec_for_statistics: PartitionSpec for the statistics. + """ # Parallel lists of spec, and params. 
param_pspec_flat, _ = jax.tree_util.tree_flatten( - params_partition_spec, is_leaf=lambda x: x is None) + params_partition_spec, is_leaf=lambda x: x is None + ) params_flat, treedef = jax.tree_util.tree_flatten(params) assert param_pspec_flat assert params_flat @@ -1667,48 +1749,57 @@ def sharded_init_partition_spec_fn(params, m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec) local_stats_flat.append( - LocalShardedParameterStats( - QuantizedValue( - param_pspec, - [], - [], - jnp.float32, - False, # pytype: disable=wrong-arg-types # numpy-scalars - list(param.shape)), - QuantizedValue( - m1_pspec, - [], - m1_scale_pspec, - qdtype, - False, # pytype: disable=wrong-arg-types # numpy-scalars - list(param.shape)), - QuantizedValue( - m2_pspec, - [], - m2_scale_pspec, - qdtype, - False, # pytype: disable=wrong-arg-types # numpy-scalars - list(param.shape)), - init_training_metrics_pspec(generate_training_metrics,), - index_start, - sizes)) + LocalShardedParameterStats( + QuantizedValue( + param_pspec, + [], + [], + jnp.float32, + False, # pytype: disable=wrong-arg-types # numpy-scalars + list(param.shape), + ), + QuantizedValue( + m1_pspec, + [], + m1_scale_pspec, + qdtype, + False, # pytype: disable=wrong-arg-types # numpy-scalars + list(param.shape), + ), + QuantizedValue( + m2_pspec, + [], + m2_scale_pspec, + qdtype, + False, # pytype: disable=wrong-arg-types # numpy-scalars + list(param.shape), + ), + init_training_metrics_pspec( + generate_training_metrics, + ), + index_start, + sizes, + ) + ) local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat) - global_stats = GlobalShardedParameterStats(partition_spec_for_statistics, - partition_spec_for_statistics, - jax.sharding.PartitionSpec()) + global_stats = GlobalShardedParameterStats( + partition_spec_for_statistics, + partition_spec_for_statistics, + jax.sharding.PartitionSpec(), + ) count_pspec = jax.sharding.PartitionSpec() return ShampooState( # pytype: disable=wrong-arg-types # numpy-scalars - count=count_pspec, - stats=ShardedShampooStats(global_stats, local_stats)) + count=count_pspec, stats=ShardedShampooStats(global_stats, local_stats) + ) def sharded_init_shape_and_dtype_fn(params): """Returns a parallel state tree with shape, dtype associated with state. - Args: - params: A pytree with params. - """ + Args: + params: A pytree with params. + """ # Parallel lists of spec, and params. 
params_flat, treedef = jax.tree_util.tree_flatten(params) assert params_flat @@ -1739,31 +1830,39 @@ def sharded_init_shape_and_dtype_fn(params): diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype] local_stats_flat.append( - LocalShardedParameterStats( - QuantizedValue( - diagonal_statistics_shape_and_dtype, - [], - [], # pytype: disable=wrong-arg-types # numpy-scalars - jnp.float32, - False, - list(param.shape)), - QuantizedValue(m1_shape_and_dtype, [], - m1_scale_shape_and_dtype, - qdtype, - False, - list(param.shape)), - QuantizedValue(m2_shape_and_dtype, [], - m2_scale_shape_and_dtype, - qdtype, - False, - list(param.shape)), - init_training_metrics_shapes( - len(sizes), - generate_training_metrics, - ), - index_start, - sizes, - )) + LocalShardedParameterStats( + QuantizedValue( + diagonal_statistics_shape_and_dtype, + [], + [], # pytype: disable=wrong-arg-types # numpy-scalars + jnp.float32, + False, + list(param.shape), + ), + QuantizedValue( + m1_shape_and_dtype, + [], + m1_scale_shape_and_dtype, + qdtype, + False, + list(param.shape), + ), + QuantizedValue( + m2_shape_and_dtype, + [], + m2_scale_shape_and_dtype, + qdtype, + False, + list(param.shape), + ), + init_training_metrics_shapes( + len(sizes), + generate_training_metrics, + ), + index_start, + sizes, + ) + ) local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat) max_statistics_size = _max_statistics_size_from_params(params_flat) @@ -1773,29 +1872,36 @@ def sharded_init_shape_and_dtype_fn(params): num_statistics = num_devices_for_pjit max_statistics_size = block_size statistics_shape = [ - num_statistics, max_statistics_size, max_statistics_size + num_statistics, + max_statistics_size, + max_statistics_size, ] preconditioners_shape = [ - num_statistics, max_statistics_size, precond_dim(max_statistics_size) + num_statistics, + max_statistics_size, + precond_dim(max_statistics_size), ] global_stats = GlobalShardedParameterStats( - [statistics_shape, jnp.float32], [preconditioners_shape, jnp.float32], - [[num_statistics], jnp.int32]) + [statistics_shape, jnp.float32], + [preconditioners_shape, jnp.float32], + [[num_statistics], jnp.int32], + ) return ShampooState( # pytype: disable=wrong-arg-types # numpy-scalars - count=[[], jnp.float32], - stats=ShardedShampooStats(global_stats, local_stats)) + count=[[], jnp.float32], + stats=ShardedShampooStats(global_stats, local_stats), + ) def sharded_update_fn(grads, state, params): """Transform the input gradient and update all statistics in sharded mode. - Args: - grads: the gradient tensors for the parameters. - state: a named tuple containing the state of the optimizer - params: the parameters that should be updated. + Args: + grads: the gradient tensors for the parameters. + state: a named tuple containing the state of the optimizer + params: the parameters that should be updated. - Returns: - A tuple containing the new parameters and the new optimizer state. - """ + Returns: + A tuple containing the new parameters and the new optimizer state. 
+ """ params_flat, treedef = jax.tree_util.tree_flatten(params) grads_flat = treedef.flatten_up_to(grads) @@ -1803,43 +1909,45 @@ def sharded_update_fn(grads, state, params): local_stats_flat = treedef.flatten_up_to(state.stats.local_stats) stats_flat = [] for local_stat in local_stats_flat: - stats_flat.append(_convert_to_parameter_stats( + stats_flat.append( + _convert_to_parameter_stats( global_stats, local_stat, - )) + ) + ) new_stats_flat = jax.tree.map( - lambda g, - s, - p: _compute_stats(g, s, p, state.count), - grads_flat, - stats_flat, - params_flat) + lambda g, s, p: _compute_stats(g, s, p, state.count), + grads_flat, + stats_flat, + params_flat, + ) outputs = jax.tree.map( - lambda g, - s, - p: _transform_grad(g, s, p, state.count), - grads_flat, - new_stats_flat, - params_flat) + lambda g, s, p: _transform_grad(g, s, p, state.count), + grads_flat, + new_stats_flat, + params_flat, + ) updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ()) updates = jax.tree_util.tree_unflatten(treedef, updates_flat) new_local_stats_flat = [] for new_stat, local_stat in zip(new_stats_flat, local_stats_flat): new_local_stats_flat.append( - _convert_from_parameter_stats( - new_stat, - local_stat, - )) + _convert_from_parameter_stats( + new_stat, + local_stat, + ) + ) max_size = global_stats.statistics.shape[1] new_padded_statistics = [] padding_starts = [] for stat in new_stats_flat: new_padded_statistics.extend( - [pad_square_matrix(stat, max_size) for stat in stat.statistics]) + [pad_square_matrix(stat, max_size) for stat in stat.statistics] + ) padding_starts.extend([len(stat) for stat in stat.statistics]) # Create global stats @@ -1857,7 +1965,8 @@ def sharded_update_fn(grads, state, params): stat_dtype = new_padded_statistics[0].dtype new_padded_statistics.extend( - [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]) + [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)] + ) padding_starts += [0] * to_pad if reuse_preconditioner: @@ -1865,29 +1974,30 @@ def sharded_update_fn(grads, state, params): for stat in new_stats_flat: prev_preconditioners.extend(stat.preconditioners) prev_padded_preconditioners = pad_and_maybe_zero_preconditioners( - prev_preconditioners, - len(new_padded_statistics), - max_size, - state.count) + prev_preconditioners, len(new_padded_statistics), max_size, state.count + ) else: prev_padded_preconditioners = None new_stacked_padded_statistics = jnp.stack(new_padded_statistics) new_stacked_padded_statistics = pjit.with_sharding_constraint( - new_stacked_padded_statistics, statistics_partition_spec) + new_stacked_padded_statistics, statistics_partition_spec + ) stacked_padding_starts = jnp.array(padding_starts, jnp.int32) prev_stacked_padded_preconditioners = _maybe(jnp.stack)( - prev_padded_preconditioners) + prev_padded_preconditioners + ) prev_stacked_padded_preconditioners = _maybe(pjit.with_sharding_constraint)( - prev_padded_preconditioners, statistics_partition_spec) + prev_padded_preconditioners, statistics_partition_spec + ) def _internal_inverse_pth_root_all(): preconditioners, metrics = _matrix_inverse_pth_root_pjit( - new_stacked_padded_statistics, - global_stats.exponents, - stacked_padding_starts, - prev_stacked_padded_preconditioners, - statistics_partition_spec, + new_stacked_padded_statistics, + global_stats.exponents, + stacked_padding_starts, + prev_stacked_padded_preconditioners, + statistics_partition_spec, ) return preconditioners, metrics @@ -1903,39 +2013,47 @@ def _internal_inverse_pth_root_all(): 
preconditioners_init = new_stacked_padded_statistics[:, :, :pd] n = new_stacked_padded_statistics.shape[0] metrics_init = cast( - TrainingMetrics, - init_training_metrics( - n, - generate_training_metrics=True, - )) + TrainingMetrics, + init_training_metrics( + n, + generate_training_metrics=True, + ), + ) new_errors = jnp.ones_like(metrics_init.inverse_pth_root_errors) * ( - inverse_failure_threshold) + inverse_failure_threshold + ) metrics_init = metrics_init.replace(inverse_pth_root_errors=new_errors) init_state = [preconditioners_init, metrics_init] new_preconditioners, metrics = efficient_cond( - perform_step, _internal_inverse_pth_root_all, init_state) + perform_step, _internal_inverse_pth_root_all, init_state + ) if generate_training_metrics: new_local_stats_flat = _add_metrics_into_local_stats( - new_local_stats_flat, metrics, ~perform_step) - new_local_stats = jax.tree_util.tree_unflatten(treedef, - new_local_stats_flat) + new_local_stats_flat, metrics, ~perform_step + ) + new_local_stats = jax.tree_util.tree_unflatten( + treedef, new_local_stats_flat + ) errors = metrics.inverse_pth_root_errors errors = errors.reshape((-1, 1, 1)) predicate = jnp.logical_or( - jnp.isnan(errors), - errors >= inverse_failure_threshold).astype(new_preconditioners.dtype) + jnp.isnan(errors), errors >= inverse_failure_threshold + ).astype(new_preconditioners.dtype) # TODO(rohananil): Check for numerical instabilities. new_conditional_preconditioners = ( - predicate * global_stats.preconditioners + - (1.0 - predicate) * new_preconditioners) + predicate * global_stats.preconditioners + + (1.0 - predicate) * new_preconditioners + ) new_global_stats = GlobalShardedParameterStats( - new_stacked_padded_statistics, - new_conditional_preconditioners, - global_stats.exponents) + new_stacked_padded_statistics, + new_conditional_preconditioners, + global_stats.exponents, + ) new_shampoo_state = ShampooState( - count=state.count + 1, - stats=ShardedShampooStats(new_global_stats, new_local_stats)) + count=state.count + 1, + stats=ShardedShampooStats(new_global_stats, new_local_stats), + ) return updates, new_shampoo_state def init_fn(params): @@ -1948,13 +2066,13 @@ def _init(param): if not _skip_preconditioning(param): shapes = preconditioner.shapes_for_preconditioners() statistics = [ - matrix_epsilon * jnp.eye(s[0], dtype=jnp.float32) for s in shapes + matrix_epsilon * jnp.eye(s[0], dtype=jnp.float32) for s in shapes ] # If the preconditioner is using a low-rank representation, initialize # it to zero instead of an invalid eye. 
preconditioners = [ - jnp.eye(s[0], s[1], dtype=jnp.float32) * (s[0] == s[1]) - for s in shapes + jnp.eye(s[0], s[1], dtype=jnp.float32) * (s[0] == s[1]) + for s in shapes ] diagonal_statistics = [] @@ -1967,25 +2085,28 @@ def _init(param): momentum = jnp.zeros_like(param) return ParameterStats( - diagonal_statistics, - statistics, - preconditioners, - # _quantize_diagonal_statistics(diagonal_statistics), - # _maybe_quantize_statistics(statistics), - # _maybe_quantize_preconditioners(preconditioners), - diagonal_momentum, - momentum, - init_training_metrics( - len(statistics), - generate_training_metrics, - )) + diagonal_statistics, + statistics, + preconditioners, + # _quantize_diagonal_statistics(diagonal_statistics), + # _maybe_quantize_statistics(statistics), + # _maybe_quantize_preconditioners(preconditioners), + diagonal_momentum, + momentum, + init_training_metrics( + len(statistics), + generate_training_metrics, + ), + ) return ShampooState( - count=jnp.zeros([], jnp.int32), stats=jax.tree.map(_init, params)) + count=jnp.zeros([], jnp.int32), stats=jax.tree.map(_init, params) + ) def _skip_preconditioning(param): return len(param.shape) < skip_preconditioning_rank_lt or any( - s > skip_preconditioning_dim_size_gt for s in param.shape) + s > skip_preconditioning_dim_size_gt for s in param.shape + ) def _compute_stats(grad, state, param, step): """Compute per-parameter statistics.""" @@ -1997,105 +2118,117 @@ def _compute_stats(grad, state, param, step): def compute_updated_statistics(): return preconditioner.updated_statistics_from_grad( - state.statistics, - grad, - w1=w1, - w2=w2, - to_float=_to_float, - from_float=lambda x: x, - # from_float=lambda x: _maybe_quantize_statistics([x])[0], - precision=tensordot_precision, + state.statistics, + grad, + w1=w1, + w2=w2, + to_float=_to_float, + from_float=lambda x: x, + # from_float=lambda x: _maybe_quantize_statistics([x])[0], + precision=tensordot_precision, ) if statistics_compute_steps > 1: perform_step = step % statistics_compute_steps == 0 init_state = state.statistics new_statistics = list( - efficient_cond(perform_step, compute_updated_statistics, - init_state)) + efficient_cond(perform_step, compute_updated_statistics, init_state) + ) else: new_statistics = compute_updated_statistics() - return ParameterStats(state.diagonal_statistics, - new_statistics, - state.preconditioners, - state.diagonal_momentum, - state.momentum, - state.training_metrics) + return ParameterStats( + state.diagonal_statistics, + new_statistics, + state.preconditioners, + state.diagonal_momentum, + state.momentum, + state.training_metrics, + ) mi_pth_root = functools.partial( - matrix_inverse_pth_root, - ridge_epsilon=matrix_epsilon, - precision=precision, - relative_matrix_epsilon=relative_matrix_epsilon, - lobpcg_topk_precondition=lobpcg_topk_precondition, - lobpcg_max_iter=lobpcg_max_iter, - eigh=eigh) + matrix_inverse_pth_root, + ridge_epsilon=matrix_epsilon, + precision=precision, + relative_matrix_epsilon=relative_matrix_epsilon, + lobpcg_topk_precondition=lobpcg_topk_precondition, + lobpcg_max_iter=lobpcg_max_iter, + eigh=eigh, + ) def _matrix_inverse_pth_root_vmap(xs, ps, padding_starts, prev): return jax.vmap(mi_pth_root)( - xs, ps, padding_start=padding_starts, prev=prev) + xs, ps, padding_start=padding_starts, prev=prev + ) - def _matrix_inverse_pth_root_pjit(xs, - ps, - padding_starts, - prev_preconds=None, - statistics_partition_spec=None): + def _matrix_inverse_pth_root_pjit( + xs, ps, padding_starts, prev_preconds=None, 
statistics_partition_spec=None + ): # Partition the concatenated statistics matrix across all cores. pspec_for_partition = preconditioner_partition_spec partitioned_xs = pjit.with_sharding_constraint(xs, pspec_for_partition) if preconditioner_partition_spec: partitioned_ps_spec = jax.sharding.PartitionSpec( - preconditioner_partition_spec[0]) + preconditioner_partition_spec[0] + ) else: partitioned_ps_spec = None partitioned_ps = pjit.with_sharding_constraint(ps, partitioned_ps_spec) partitioned_prev_preconds = _maybe(pjit.with_sharding_constraint)( - prev_preconds, preconditioner_partition_spec) + prev_preconds, preconditioner_partition_spec + ) partitioned_padding_starts = pjit.with_sharding_constraint( - padding_starts, partitioned_ps_spec) # paddings are scalars like ps. + padding_starts, partitioned_ps_spec + ) # paddings are scalars like ps. # Run matrix inverse pth root on each shard. partitioned_preconditioners, partitioned_metrics = ( - _matrix_inverse_pth_root_vmap( - partitioned_xs, - partitioned_ps, - partitioned_padding_starts, - prev=partitioned_prev_preconds)) + _matrix_inverse_pth_root_vmap( + partitioned_xs, + partitioned_ps, + partitioned_padding_starts, + prev=partitioned_prev_preconds, + ) + ) # Reshard output to have the same PSpec as input. This is required to avoid # vmap seeing the full set of statistics. partitioned_preconditioners = pjit.with_sharding_constraint( - partitioned_preconditioners, pspec_for_partition) + partitioned_preconditioners, pspec_for_partition + ) # Recombine the outputs at each core. - preconditioners = pjit.with_sharding_constraint(partitioned_preconditioners, - statistics_partition_spec) - metrics = pjit.with_sharding_constraint(partitioned_metrics, - jax.sharding.PartitionSpec()) + preconditioners = pjit.with_sharding_constraint( + partitioned_preconditioners, statistics_partition_spec + ) + metrics = pjit.with_sharding_constraint( + partitioned_metrics, jax.sharding.PartitionSpec() + ) return preconditioners, metrics - def _pmap_compute_preconditioners(states, - step, - statistics, - num_statistics_per_state, - original_shapes, - exponents, - max_size, - prev_preconditioners): + def _pmap_compute_preconditioners( + states, + step, + statistics, + num_statistics_per_state, + original_shapes, + exponents, + max_size, + prev_preconditioners, + ): """Computes preconditioners for given statistics in states in PMAP mode. - Args: - states: A list of optimizer states. - step: Current step number - statistics: A list of statistics for all variables (for every dim) - num_statistics_per_state: Number of statistis per state to reconstruct - output states. - original_shapes: A list of shapes of the statistics. - exponents: Exponent power to use for inverse-pth roots. - max_size: Maximum dim of the statistics to pad. - prev_preconditioners: Previously available preconditioner. - - Returns: - New optimizer states after computing the preconditioner. - """ + Args: + states: A list of optimizer states. + step: Current step number + statistics: A list of statistics for all variables (for every dim) + num_statistics_per_state: Number of statistis per state to reconstruct + output states. + original_shapes: A list of shapes of the statistics. + exponents: Exponent power to use for inverse-pth roots. + max_size: Maximum dim of the statistics to pad. + prev_preconditioners: Previously available preconditioner. + + Returns: + New optimizer states after computing the preconditioner. 
+ """ if batch_axis_name: num_devices = lax.psum(1, batch_axis_name) else: @@ -2103,13 +2236,15 @@ def _pmap_compute_preconditioners(states, num_statistics = len(statistics) # Pad statistics and exponents to next multiple of num_devices. packed_statistics = [ - pad_square_matrix(stat, max_size) for stat in statistics + pad_square_matrix(stat, max_size) for stat in statistics ] to_pad = -num_statistics % num_devices - packed_statistics.extend([ + packed_statistics.extend( + [ jnp.eye(max_size, dtype=packed_statistics[0].dtype) for _ in range(to_pad) - ]) + ] + ) exponents.extend([1 for _ in range(to_pad)]) paddings = [len(stat) for stat in statistics] + [0] * to_pad @@ -2119,7 +2254,8 @@ def _pmap_compute_preconditioners(states, if reuse_preconditioner: assert len(prev_preconditioners) == num_statistics packed_preconditioners = pad_and_maybe_zero_preconditioners( - prev_preconditioners, len(packed_statistics), max_size, step) + prev_preconditioners, len(packed_statistics), max_size, step + ) else: packed_preconditioners = None @@ -2132,10 +2268,10 @@ def _internal_inverse_pth_root_all(): if batch_axis_name: current_replica = lax.axis_index(batch_axis_name) preconditioners, metrics = _matrix_inverse_pth_root_vmap( - all_statistics[current_replica], - all_exponents[current_replica], - all_paddings[current_replica], - _maybe_ix(all_preconditioners, current_replica), + all_statistics[current_replica], + all_exponents[current_replica], + all_paddings[current_replica], + _maybe_ix(all_preconditioners, current_replica), ) preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name) metrics = jax.lax.all_gather(metrics, batch_axis_name) @@ -2143,14 +2279,15 @@ def _internal_inverse_pth_root_all(): metrics_flat = jax.tree.map(unbatch, metrics) else: preconditioners, metrics = _matrix_inverse_pth_root_vmap( - all_statistics[0], - all_exponents[0], - all_paddings[0], - _maybe_ix(all_preconditioners, 0), + all_statistics[0], + all_exponents[0], + all_paddings[0], + _maybe_ix(all_preconditioners, 0), ) preconditioners_flat = unbatch(jnp.stack([preconditioners])) metrics = jax.tree.map( - functools.partial(jnp.expand_dims, axis=0), metrics) + functools.partial(jnp.expand_dims, axis=0), metrics + ) metrics_flat = jax.tree.map(unbatch, metrics) return preconditioners_flat, metrics_flat @@ -2163,40 +2300,57 @@ def _internal_inverse_pth_root_all(): # shaped tensors. Note statistics will be ignored as we are passing in # a large error value. 
preconditioners_init = [ - s[:, :precond_dim(s.shape[0])] for s in packed_statistics + s[:, : precond_dim(s.shape[0])] for s in packed_statistics ] n = len(packed_statistics) metrics_init = jax.tree.map( - lambda x: [x] * n, - default_training_metrics().replace( - inverse_pth_root_errors=inverse_failure_threshold)) + lambda x: [x] * n, + default_training_metrics().replace( + inverse_pth_root_errors=inverse_failure_threshold + ), + ) init_state = [preconditioners_init, metrics_init] preconditioners_flat, metrics_flat = efficient_cond( - perform_step, _internal_inverse_pth_root_all, init_state) + perform_step, _internal_inverse_pth_root_all, init_state + ) def _skip(error): condition = jnp.logical_or( - jnp.isnan(error), error >= inverse_failure_threshold) + jnp.isnan(error), error >= inverse_failure_threshold + ) return condition.astype(error.dtype) def _select_preconditioner(error, new_p, old_p): return lax.cond( - _skip(error), lambda _: old_p, lambda _: new_p, operand=None) + _skip(error), lambda _: old_p, lambda _: new_p, operand=None + ) new_preconditioners_flat = [] new_errors_flat = metrics_flat.inverse_pth_root_errors - for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes, - prev_preconditioners, new_errors_flat): + for p, shape, prev_p, error in zip( + preconditioners_flat, + original_shapes, + prev_preconditioners, + new_errors_flat, + ): new_preconditioners_flat.append( - _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p)) + _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p) + ) - assert len(states) == (len(num_statistics_per_state), - f"{len(states)} vs {len(num_statistics_per_state)}") + assert len(states) == ( + len(num_statistics_per_state), + f'{len(states)} vs {len(num_statistics_per_state)}', + ) assert len(new_preconditioners_flat) == num_statistics assert len(new_errors_flat) == len(packed_statistics), ( - len(new_errors_flat), len(packed_statistics)) + len(new_errors_flat), + len(packed_statistics), + ) assert len(new_errors_flat) == num_statistics + to_pad, ( - len(new_errors_flat), num_statistics, to_pad) + len(new_errors_flat), + num_statistics, + to_pad, + ) # Add back empty preconditioners so we that we can set the optimizer state. preconditioners_for_states = [] @@ -2206,26 +2360,31 @@ def _select_preconditioner(error, new_p, old_p): if num_statistics == 0: preconditioners_for_states.append([]) metrics_for_states.append( - init_training_metrics(0, generate_training_metrics)) + init_training_metrics(0, generate_training_metrics) + ) else: - preconditioners_for_state = new_preconditioners_flat[idx:idx + - num_statistics] + preconditioners_for_state = new_preconditioners_flat[ + idx : idx + num_statistics + ] assert len(state.statistics) == len(preconditioners_for_state) preconditioners_for_states.append(preconditioners_for_state) if generate_training_metrics: # pylint:disable=cell-var-from-loop Used immediately. metrics_for_state = jax.tree.map( - lambda x: jnp.stack(x[idx:idx + num_statistics]), - metrics_flat, - is_leaf=lambda x: isinstance(x, list)) + lambda x: jnp.stack(x[idx : idx + num_statistics]), + metrics_flat, + is_leaf=lambda x: isinstance(x, list), + ) assert jax.tree_util.tree_all( - jax.tree.map(lambda x: len(state.statistics) == len(x), - metrics_for_state)) + jax.tree.map( + lambda x: len(state.statistics) == len(x), metrics_for_state + ) + ) # If we skipped preconditioner computation, record old metrics. 
- metrics_for_state = efficient_cond(perform_step, - lambda: [metrics_for_state], - [state.training_metrics])[0] + metrics_for_state = efficient_cond( + perform_step, lambda: [metrics_for_state], [state.training_metrics] + )[0] # pylint:enable=cell-var-from-loop else: metrics_for_state = optax.MaskedNode() @@ -2234,32 +2393,36 @@ def _select_preconditioner(error, new_p, old_p): idx += num_statistics new_states = [] for state, new_preconditioners, new_metrics in zip( - states, preconditioners_for_states, metrics_for_states): + states, preconditioners_for_states, metrics_for_states + ): # Note the preconditioner may have been skipped, but we still update the # metrics with the new error values; whether the preconditioner that's # actively being used is stale can be derived from the new_metrics # being greater than the failure threshold. new_states.append( - ParameterStats(state.diagonal_statistics, - state.statistics, - new_preconditioners, - state.diagonal_momentum, - state.momentum, - new_metrics)) + ParameterStats( + state.diagonal_statistics, + state.statistics, + new_preconditioners, + state.diagonal_momentum, + state.momentum, + new_metrics, + ) + ) return new_states def _compute_preconditioners(states, params, step): """Computes preconditioners for given statistics in states. - Args: - states: A list of optimizer states. - params: A list of params. - step: Current step number + Args: + states: A list of optimizer states. + params: A list of params. + step: Current step number - Returns: - New optimizer states after computing the preconditioner. - """ + Returns: + New optimizer states after computing the preconditioner. + """ statistics = [] num_statistics_per_state = [] original_shapes = [] @@ -2274,8 +2437,11 @@ def _compute_preconditioners(states, params, step): if num_statistics > 0: preconditioner = preconditioner_from_params(param) for statistic in state.statistics: - exponents.append(preconditioner.exponent_for_preconditioner( - ) if exponent_override == 0 else exponent_override) + exponents.append( + preconditioner.exponent_for_preconditioner() + if exponent_override == 0 + else exponent_override + ) original_shapes_for_state.append(statistic.shape) max_size = max(max_size, statistic.shape[0]) @@ -2283,14 +2449,16 @@ def _compute_preconditioners(states, params, step): prev_preconditioners.extend(state.preconditioners) original_shapes.extend(original_shapes_for_state) - return _pmap_compute_preconditioners(states, - step, - statistics, - num_statistics_per_state, - original_shapes, - exponents, - max_size, - prev_preconditioners) + return _pmap_compute_preconditioners( + states, + step, + statistics, + num_statistics_per_state, + original_shapes, + exponents, + max_size, + prev_preconditioners, + ) def _transform_grad(grad, state, param, step): """Transform per-parameter gradients.""" @@ -2298,21 +2466,25 @@ def _transform_grad(grad, state, param, step): sgd_update = grad new_diagonal_statistics = state.diagonal_statistics - if (graft_type == GraftingType.ADAGRAD or - graft_type == GraftingType.ADAGRAD_NORMALIZED): - + if ( + graft_type == GraftingType.ADAGRAD + or graft_type == GraftingType.ADAGRAD_NORMALIZED + ): scaled_grad = grad if graft_type == GraftingType.ADAGRAD_NORMALIZED: scaled_grad = grad / (jnp.linalg.norm(grad) + _EPSILON) new_diagonal_statistics = ( - state.diagonal_statistics.to_float() + jnp.square(scaled_grad)) + state.diagonal_statistics.to_float() + jnp.square(scaled_grad) + ) adagrad_update = scaled_grad / ( - jnp.sqrt(new_diagonal_statistics) + 
diagonal_epsilon) + jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon + ) grafting_update = adagrad_update - elif (graft_type == GraftingType.RMSPROP or - graft_type == GraftingType.RMSPROP_NORMALIZED): - + elif ( + graft_type == GraftingType.RMSPROP + or graft_type == GraftingType.RMSPROP_NORMALIZED + ): scaled_grad = grad if graft_type == GraftingType.RMSPROP_NORMALIZED: scaled_grad = grad / (jnp.linalg.norm(grad) + _EPSILON) @@ -2321,15 +2493,19 @@ def _transform_grad(grad, state, param, step): w2 = jnp.where(beta2 == 1.0, beta2, 1.0 - beta2) new_diagonal_statistics = ( - w1 * state.diagonal_statistics + w2 * jnp.square(scaled_grad)) + w1 * state.diagonal_statistics + w2 * jnp.square(scaled_grad) + ) rmsprop_update = scaled_grad / ( - jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon) + jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon + ) if clip_by_scaled_gradient_norm: scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / ( - jnp.sqrt(float(rmsprop_update.size))) + jnp.sqrt(float(rmsprop_update.size)) + ) clipping_denom = jnp.maximum( - 1., scaled_grad_norm / clip_by_scaled_gradient_norm) + 1.0, scaled_grad_norm / clip_by_scaled_gradient_norm + ) rmsprop_update /= clipping_denom grafting_update = rmsprop_update @@ -2349,12 +2525,14 @@ def _transform_grad(grad, state, param, step): precond_grad = grad if not _skip_preconditioning(param): - precond_grad = preconditioner.preconditioned_grad(precond_grad, - state.preconditioners) + precond_grad = preconditioner.preconditioned_grad( + precond_grad, state.preconditioners + ) else: if graft_type == GraftingType.NONE: - logging.error("skipping preconditioning without grafting for param %s", - param) + logging.error( + 'skipping preconditioning without grafting for param %s', param + ) precond_grad = grafting_update grafting_update_norm = jnp.linalg.norm(grafting_update) @@ -2369,39 +2547,49 @@ def _transform_grad(grad, state, param, step): shampoo_update_with_wd = shampoo_update grafting_update_with_wd = grafting_update - if (weight_decay != 0 and weight_decay is not None and - not decoupled_weight_decay): + if ( + weight_decay != 0 + and weight_decay is not None + and not decoupled_weight_decay + ): shampoo_update_with_wd = shampoo_update + weight_decay * param grafting_update_with_wd = grafting_update + weight_decay * param w = (1.0 - beta1) if moving_average_for_momentum else 1.0 shampoo_update_with_wd_momentum = ( - state.momentum * beta1 + w * shampoo_update_with_wd) + state.momentum * beta1 + w * shampoo_update_with_wd + ) grafting_update_with_wd_momentum = ( - state.diagonal_momentum * beta1 + w * grafting_update_with_wd) + state.diagonal_momentum * beta1 + w * grafting_update_with_wd + ) run_shampoo = (step >= start_preconditioning_step).astype( - grafting_update_with_wd_momentum.dtype) + grafting_update_with_wd_momentum.dtype + ) momentum_update = ( - run_shampoo * shampoo_update_with_wd_momentum + - (1.0 - run_shampoo) * grafting_update_with_wd_momentum) + run_shampoo * shampoo_update_with_wd_momentum + + (1.0 - run_shampoo) * grafting_update_with_wd_momentum + ) wd_update = ( - run_shampoo * shampoo_update_with_wd + - (1.0 - run_shampoo) * grafting_update_with_wd) + run_shampoo * shampoo_update_with_wd + + (1.0 - run_shampoo) * grafting_update_with_wd + ) nesterov_momentum_update = momentum_update if nesterov: nesterov_momentum_update = w * wd_update + beta1 * momentum_update - if (weight_decay != 0 and weight_decay is not None and - decoupled_weight_decay): + if ( + weight_decay != 0 and weight_decay is not None and 
decoupled_weight_decay + ): nesterov_momentum_update = ( - nesterov_momentum_update + lr * weight_decay * param) + nesterov_momentum_update + lr * weight_decay * param + ) momentum_multiplier = lr if decoupled_learning_rate else 1.0 transformed_update = -1.0 * momentum_multiplier * nesterov_momentum_update @@ -2409,26 +2597,28 @@ def _transform_grad(grad, state, param, step): new_diagonal_momentum = grafting_update_with_wd_momentum new_momentum = shampoo_update_with_wd_momentum - param_stats = ParameterStats(new_diagonal_statistics, - state.statistics, - state.preconditioners, - new_diagonal_momentum, - new_momentum, - state.training_metrics) + param_stats = ParameterStats( + new_diagonal_statistics, + state.statistics, + state.preconditioners, + new_diagonal_momentum, + new_momentum, + state.training_metrics, + ) return transformed_update, param_stats def update_fn(grads, state, params): """Transform the input gradient and update all statistics. - Args: - grads: the gradient tensors for the parameters and any custom - gradients for preconditioners. - state: a named tuple containing the state of the optimizer - params: the parameters that should be updated. + Args: + grads: the gradient tensors for the parameters and any custom + gradients for preconditioners. + state: a named tuple containing the state of the optimizer + params: the parameters that should be updated. - Returns: - A tuple containing the new parameters and the new optimizer state. - """ + Returns: + A tuple containing the new parameters and the new optimizer state. + """ grads_custom = None if custom_preconditioner and isinstance(grads, tuple): grads, grads_custom = grads @@ -2442,23 +2632,21 @@ def update_fn(grads, state, params): stats_grads = treedef.flatten_up_to(grads_custom) new_stats_flat = jax.tree.map( - lambda g, - s, - p: _compute_stats(g, s, p, state.count), - stats_grads, - stats_flat, - params_flat) - - new_stats_flat = _compute_preconditioners(new_stats_flat, - params_flat, - state.count) + lambda g, s, p: _compute_stats(g, s, p, state.count), + stats_grads, + stats_flat, + params_flat, + ) + + new_stats_flat = _compute_preconditioners( + new_stats_flat, params_flat, state.count + ) outputs = jax.tree.map( - lambda g, - s, - p: _transform_grad(g, s, p, state.count), - grads_flat, - new_stats_flat, - params_flat) + lambda g, s, p: _transform_grad(g, s, p, state.count), + grads_flat, + new_stats_flat, + params_flat, + ) updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ()) updates = jax.tree_util.tree_unflatten(treedef, updates_flat) new_stats = jax.tree_util.tree_unflatten(treedef, new_stats_flat) @@ -2472,9 +2660,10 @@ def update_fn(grads, state, params): def _init_fns(unused_params): return InitFnState( - init_fn=opt_init_fn, - pspec_fn=sharded_init_partition_spec_fn, - shape_and_dtype_fn=sharded_init_shape_and_dtype_fn) + init_fn=opt_init_fn, + pspec_fn=sharded_init_partition_spec_fn, + shape_and_dtype_fn=sharded_init_shape_and_dtype_fn, + ) opt_update_fn = sharded_update_fn return optax.GradientTransformation(_init_fns, opt_update_fn) diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index 2cd054062..8bf4d2dc5 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -3,24 +3,27 @@ import functools from typing import Any, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax 
import lax import jax.numpy as jnp import optax +from flax import jax_utils +from jax import lax from algoperf import spec -from reference_algorithms.paper_baselines.shampoo.jax.distributed_shampoo import \ - distributed_shampoo +from reference_algorithms.paper_baselines.shampoo.jax.distributed_shampoo import ( + distributed_shampoo, +) _GRAD_CLIP_EPS = 1e-6 -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Shampoo optimizer and a learning rate schedule.""" del model_params del model_state @@ -30,102 +33,116 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps, + ) cosine_steps = max(step_hint - warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps] + ) return schedule_fn # Create optimizer + LR schedule. lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) opt_init_fn, opt_update_fn = distributed_shampoo( - learning_rate=lr_schedule_fn, - beta1=1.0 - hyperparameters.one_minus_beta1, - beta2=hyperparameters.beta2, - weight_decay=hyperparameters.weight_decay, - batch_axis_name='batch', - eigh=False) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + learning_rate=lr_schedule_fn, + beta1=1.0 - hyperparameters.one_minus_beta1, + beta2=hyperparameters.beta2, + weight_decay=hyperparameters.weight_decay, + batch_axis_name='batch', + eigh=False, + ) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = 
workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -142,37 +159,43 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + outputs = pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. 
if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters @@ -208,14 +231,16 @@ def get_batch_size(workload_name): raise ValueError(f'Unsupported workload name: {workload_name}.') -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. 
""" diff --git a/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json b/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json index 9d804ba0e..58f6f4fd1 100644 --- a/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json +++ b/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json @@ -1,23 +1,29 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "min": 1e-2, "max": 0.15, "scaling": "log" - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 5e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "min": 1e-2, + "max": 0.15, + "scaling": "log" + }, + "beta2": { + "feasible_points": [0.999] + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 5e-3, + "max": 1.0, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + } } diff --git a/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json index b8bd2ea49..5a7c27be7 100644 --- a/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json +++ b/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json @@ -1,23 +1,27 @@ { - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "feasible_points": [0.1] - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 5e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } + "learning_rate": { + "min": 1e-4, + "max": 1e-2, + "scaling": "log" + }, + "one_minus_beta1": { + "feasible_points": [0.1] + }, + "beta2": { + "feasible_points": [0.999] + }, + "warmup_factor": { + "feasible_points": [0.05] + }, + "weight_decay": { + "min": 5e-3, + "max": 1.0, + "scaling": "log" + }, + "label_smoothing": { + "feasible_points": [0.1, 0.2] + }, + "dropout_rate": { + "feasible_points": [0.0, 0.1] + } } diff --git a/reference_algorithms/target_setting_algorithms/cosine_warmup.py b/reference_algorithms/target_setting_algorithms/cosine_warmup.py index 116ebc555..6a2241732 100644 --- a/reference_algorithms/target_setting_algorithms/cosine_warmup.py +++ b/reference_algorithms/target_setting_algorithms/cosine_warmup.py @@ -1,35 +1,37 @@ """Implementions of a linear warmup then cosine decay LR schedule.""" import optax -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. 
warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=hyperparameters.warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=hyperparameters.warmup_steps, + ) cosine_steps = max(step_hint - hyperparameters.warmup_steps, 1) cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps + ) schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], - boundaries=[hyperparameters.warmup_steps]) + schedules=[warmup_fn, cosine_fn], boundaries=[hyperparameters.warmup_steps] + ) return schedule_fn def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): warmup = LinearLR( - optimizer, - start_factor=1e-10, - end_factor=1., - total_iters=hyperparameters.warmup_steps) + optimizer, + start_factor=1e-10, + end_factor=1.0, + total_iters=hyperparameters.warmup_steps, + ) cosine_steps = max(step_hint - hyperparameters.warmup_steps, 1) cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) return SequentialLR( - optimizer, - schedulers=[warmup, cosine_decay], - milestones=[hyperparameters.warmup_steps]) + optimizer, + schedulers=[warmup, cosine_decay], + milestones=[hyperparameters.warmup_steps], + ) diff --git a/reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json index cab6fd5f7..6061940a9 100644 --- a/reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json @@ -1,27 +1,17 @@ { "learning_rate": { - "feasible_points": [ - 0.0033313215673016375 - ] + "feasible_points": [0.0033313215673016375] }, "beta1": { - "feasible_points": [ - 0.948000082541717 - ] + "feasible_points": [0.948000082541717] }, "beta2": { - "feasible_points": [ - 0.9987934318891598 - ] + "feasible_points": [0.9987934318891598] }, "warmup_steps": { - "feasible_points": [ - 159 - ] + "feasible_points": [159] }, "weight_decay": { - "feasible_points": [ - 0.0035784380304876183 - ] + "feasible_points": [0.0035784380304876183] } } diff --git a/reference_algorithms/target_setting_algorithms/criteo1tb_embed_init/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/criteo1tb_embed_init/tuning_search_space.json index bd6c9702f..110138607 100644 --- a/reference_algorithms/target_setting_algorithms/criteo1tb_embed_init/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/criteo1tb_embed_init/tuning_search_space.json @@ -1,28 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.002517072211464665 - ] - }, - "beta1": { - "feasible_points": [ - 0.9908351643533544 - ] - }, - "beta2": { - "feasible_points": [ - 0.9859568907533993 - ] - }, - "warmup_steps": { - "feasible_points": [ - 799 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.12274552870237089 - ] - } + "learning_rate": { + "feasible_points": [0.002517072211464665] + }, + "beta1": { + "feasible_points": [0.9908351643533544] + }, + "beta2": { + "feasible_points": [0.9859568907533993] + }, + "warmup_steps": { + "feasible_points": [799] + }, + "weight_decay": { + "feasible_points": [0.12274552870237089] } - \ No newline at end of file +} diff --git a/reference_algorithms/target_setting_algorithms/criteo1tb_layernorm/tuning_search_space.json 
b/reference_algorithms/target_setting_algorithms/criteo1tb_layernorm/tuning_search_space.json index 8d128dae1..a7f52681d 100644 --- a/reference_algorithms/target_setting_algorithms/criteo1tb_layernorm/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/criteo1tb_layernorm/tuning_search_space.json @@ -1,28 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.05493199486120455 - ] - }, - "beta1": { - "feasible_points": [ - 0.954922991734919 - ] - }, - "beta2": { - "feasible_points": [ - 0.9986188074995163 - ] - }, - "warmup_steps": { - "feasible_points": [ - 799 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.00011065469792077193 - ] - } + "learning_rate": { + "feasible_points": [0.05493199486120455] + }, + "beta1": { + "feasible_points": [0.954922991734919] + }, + "beta2": { + "feasible_points": [0.9986188074995163] + }, + "warmup_steps": { + "feasible_points": [799] + }, + "weight_decay": { + "feasible_points": [0.00011065469792077193] } - \ No newline at end of file +} diff --git a/reference_algorithms/target_setting_algorithms/criteo1tb_resnet/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/criteo1tb_resnet/tuning_search_space.json index a33ae2ff5..31ce92bd1 100644 --- a/reference_algorithms/target_setting_algorithms/criteo1tb_resnet/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/criteo1tb_resnet/tuning_search_space.json @@ -1,28 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.001493629901423942 - ] - }, - "beta1": { - "feasible_points": [ - 0.9592129978682067 - ] - }, - "beta2": { - "feasible_points": [ - 0.9824918272399145 - ] - }, - "warmup_steps": { - "feasible_points": [ - 399 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.00038587516415285595 - ] - } + "learning_rate": { + "feasible_points": [0.001493629901423942] + }, + "beta1": { + "feasible_points": [0.9592129978682067] + }, + "beta2": { + "feasible_points": [0.9824918272399145] + }, + "warmup_steps": { + "feasible_points": [399] + }, + "weight_decay": { + "feasible_points": [0.00038587516415285595] } - \ No newline at end of file +} diff --git a/reference_algorithms/target_setting_algorithms/data_selection.py b/reference_algorithms/target_setting_algorithms/data_selection.py index 5e70f9f8b..e0d9c0ee9 100644 --- a/reference_algorithms/target_setting_algorithms/data_selection.py +++ b/reference_algorithms/target_setting_algorithms/data_selection.py @@ -4,14 +4,15 @@ def data_selection( - workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. Each element of the queue is a batch of training examples and labels. 
""" diff --git a/reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json index d8b4ed1b9..894ebb9fb 100644 --- a/reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json @@ -1,37 +1,23 @@ { - "learning_rate": { - "feasible_points": [ - 0.028609 - ] - }, - "beta1": { - "feasible_points": [ - 0.981543 - ] - }, - "beta2": { - "feasible_points": [ - 0.9978504782314613 - ] - }, - "warmup_steps": { - "feasible_points": [ - 1357 - ] - }, - "decay_steps_factor": { - "feasible_points": [ - 0.984398 - ] - }, - "end_factor": { - "feasible_points": [ - 0.01 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.000576 - ] - } + "learning_rate": { + "feasible_points": [0.028609] + }, + "beta1": { + "feasible_points": [0.981543] + }, + "beta2": { + "feasible_points": [0.9978504782314613] + }, + "warmup_steps": { + "feasible_points": [1357] + }, + "decay_steps_factor": { + "feasible_points": [0.984398] + }, + "end_factor": { + "feasible_points": [0.01] + }, + "weight_decay": { + "feasible_points": [0.000576] + } } diff --git a/reference_algorithms/target_setting_algorithms/fastmri_layernorm/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/fastmri_layernorm/tuning_search_space.json index 7a49ea891..a3aa8ea08 100644 --- a/reference_algorithms/target_setting_algorithms/fastmri_layernorm/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/fastmri_layernorm/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.008334676559764446 - ] - }, - "beta1": { - "feasible_points": [ - 0.8294338711079317 - ] - }, - "beta2": { - "feasible_points": [ - 0.8551723332825868 - ] - }, - "warmup_steps": { - "feasible_points": [ - 2714 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.01371235755699044 - ] - } + "learning_rate": { + "feasible_points": [0.008334676559764446] + }, + "beta1": { + "feasible_points": [0.8294338711079317] + }, + "beta2": { + "feasible_points": [0.8551723332825868] + }, + "warmup_steps": { + "feasible_points": [2714] + }, + "weight_decay": { + "feasible_points": [0.01371235755699044] + } } diff --git a/reference_algorithms/target_setting_algorithms/fastmri_model_size/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/fastmri_model_size/tuning_search_space.json index 5516242df..21c5ac87c 100644 --- a/reference_algorithms/target_setting_algorithms/fastmri_model_size/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/fastmri_model_size/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.006173154695175443 - ] - }, - "beta1": { - "feasible_points": [ - 0.8496694604806512 - ] - }, - "beta2": { - "feasible_points": [ - 0.4639437428687345 - ] - }, - "warmup_steps": { - "feasible_points": [ - 1357 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.1679001017957879 - ] - } + "learning_rate": { + "feasible_points": [0.006173154695175443] + }, + "beta1": { + "feasible_points": [0.8496694604806512] + }, + "beta2": { + "feasible_points": [0.4639437428687345] + }, + "warmup_steps": { + "feasible_points": [1357] + }, + "weight_decay": { + "feasible_points": [0.1679001017957879] + } } diff --git a/reference_algorithms/target_setting_algorithms/fastmri_tanh/tuning_search_space.json 
b/reference_algorithms/target_setting_algorithms/fastmri_tanh/tuning_search_space.json index c3f06a686..59d624fe9 100644 --- a/reference_algorithms/target_setting_algorithms/fastmri_tanh/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/fastmri_tanh/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.04037951750205473 - ] - }, - "beta1": { - "feasible_points": [ - 0.9932215932637941 - ] - }, - "beta2": { - "feasible_points": [ - 0.9425306939334134 - ] - }, - "warmup_steps": { - "feasible_points": [ - 542 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.14877061239151607 - ] - } + "learning_rate": { + "feasible_points": [0.04037951750205473] + }, + "beta1": { + "feasible_points": [0.9932215932637941] + }, + "beta2": { + "feasible_points": [0.9425306939334134] + }, + "warmup_steps": { + "feasible_points": [542] + }, + "weight_decay": { + "feasible_points": [0.14877061239151607] + } } diff --git a/reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json index 649487c48..941bac70e 100644 --- a/reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json @@ -1,42 +1,26 @@ { - "learning_rate": { - "feasible_points": [ - 4.131896390902391 - ] - }, - "beta1": { - "feasible_points": [ - 0.9274758113254791 - ] - }, - "beta2": { - "feasible_points": [ - 0.9978504782314613 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6999 - ] - }, - "decay_steps_factor": { - "feasible_points": [ - 0.9007765761611038 - ] - }, - "end_factor": { - "feasible_points": [ - 0.001 - ] - }, - "weight_decay": { - "feasible_points": [ - 5.6687777311501786e-6 - ] - }, - "label_smoothing": { - "feasible_points": [ - 0.2 - ] - } + "learning_rate": { + "feasible_points": [4.131896390902391] + }, + "beta1": { + "feasible_points": [0.9274758113254791] + }, + "beta2": { + "feasible_points": [0.9978504782314613] + }, + "warmup_steps": { + "feasible_points": [6999] + }, + "decay_steps_factor": { + "feasible_points": [0.9007765761611038] + }, + "end_factor": { + "feasible_points": [0.001] + }, + "weight_decay": { + "feasible_points": [5.6687777311501786e-6] + }, + "label_smoothing": { + "feasible_points": [0.2] + } } diff --git a/reference_algorithms/target_setting_algorithms/imagenet_resnet_gelu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_resnet_gelu/tuning_search_space.json index 6524f5a5b..f72a4057d 100644 --- a/reference_algorithms/target_setting_algorithms/imagenet_resnet_gelu/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/imagenet_resnet_gelu/tuning_search_space.json @@ -1,37 +1,23 @@ { - "learning_rate": { - "feasible_points": [ - 0.3850582234619253 - ] - }, - "beta1": { - "feasible_points": [ - 0.9845129495436189 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6999 - ] - }, - "decay_steps_factor": { - "feasible_points": [ - 0.9504205232618159 - ] - }, - "end_factor": { - "feasible_points": [ - 0.001 - ] - }, - "weight_decay": { - "feasible_points": [ - 1.7359160785435053e-5 - ] - }, - "label_smoothing": { - "feasible_points": [ - 0.2 - ] - } + "learning_rate": { + "feasible_points": [0.3850582234619253] + }, + "beta1": { + "feasible_points": [0.9845129495436189] + }, + "warmup_steps": { + "feasible_points": [6999] + }, + 
"decay_steps_factor": { + "feasible_points": [0.9504205232618159] + }, + "end_factor": { + "feasible_points": [0.001] + }, + "weight_decay": { + "feasible_points": [1.7359160785435053e-5] + }, + "label_smoothing": { + "feasible_points": [0.2] + } } diff --git a/reference_algorithms/target_setting_algorithms/imagenet_resnet_large_bn_init/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_resnet_large_bn_init/tuning_search_space.json index 7ad32bb60..f58474cc8 100644 --- a/reference_algorithms/target_setting_algorithms/imagenet_resnet_large_bn_init/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/imagenet_resnet_large_bn_init/tuning_search_space.json @@ -1,37 +1,23 @@ { - "learning_rate": { - "feasible_points": [ - 4.131896390902391 - ] - }, - "beta1": { - "feasible_points": [ - 0.9274758113254791 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6999 - ] - }, - "decay_steps_factor": { - "feasible_points": [ - 0.9007765761611038 - ] - }, - "end_factor": { - "feasible_points": [ - 0.01 - ] - }, - "weight_decay": { - "feasible_points": [ - 5.6687777311501786e-6 - ] - }, - "label_smoothing": { - "feasible_points": [ - 0.2 - ] - } + "learning_rate": { + "feasible_points": [4.131896390902391] + }, + "beta1": { + "feasible_points": [0.9274758113254791] + }, + "warmup_steps": { + "feasible_points": [6999] + }, + "decay_steps_factor": { + "feasible_points": [0.9007765761611038] + }, + "end_factor": { + "feasible_points": [0.01] + }, + "weight_decay": { + "feasible_points": [5.6687777311501786e-6] + }, + "label_smoothing": { + "feasible_points": [0.2] + } } diff --git a/reference_algorithms/target_setting_algorithms/imagenet_resnet_silu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_resnet_silu/tuning_search_space.json index 4556c6235..c63a87214 100644 --- a/reference_algorithms/target_setting_algorithms/imagenet_resnet_silu/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/imagenet_resnet_silu/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.01897755400372091 - ] - }, - "beta1": { - "feasible_points": [ - 0.9666072782043229 - ] - }, - "beta2": { - "feasible_points": [ - 0.99681600289198 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.015653883841116094 - ] - } + "learning_rate": { + "feasible_points": [0.01897755400372091] + }, + "beta1": { + "feasible_points": [0.9666072782043229] + }, + "beta2": { + "feasible_points": [0.99681600289198] + }, + "warmup_steps": { + "feasible_points": [6999] + }, + "weight_decay": { + "feasible_points": [0.015653883841116094] + } } diff --git a/reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json index 98360bdff..6c7501295 100644 --- a/reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.0008445074561975979 - ] - }, - "beta1": { - "feasible_points": [ - 0.8895758153482813 - ] - }, - "beta2": { - "feasible_points": [ - 0.9978504782314613 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.08135402759553023 - ] - } + "learning_rate": { + "feasible_points": 
[0.0008445074561975979] + }, + "beta1": { + "feasible_points": [0.8895758153482813] + }, + "beta2": { + "feasible_points": [0.9978504782314613] + }, + "warmup_steps": { + "feasible_points": [6999] + }, + "weight_decay": { + "feasible_points": [0.08135402759553023] + } } diff --git a/reference_algorithms/target_setting_algorithms/imagenet_vit_glu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_vit_glu/tuning_search_space.json index 98360bdff..6c7501295 100644 --- a/reference_algorithms/target_setting_algorithms/imagenet_vit_glu/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/imagenet_vit_glu/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.0008445074561975979 - ] - }, - "beta1": { - "feasible_points": [ - 0.8895758153482813 - ] - }, - "beta2": { - "feasible_points": [ - 0.9978504782314613 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.08135402759553023 - ] - } + "learning_rate": { + "feasible_points": [0.0008445074561975979] + }, + "beta1": { + "feasible_points": [0.8895758153482813] + }, + "beta2": { + "feasible_points": [0.9978504782314613] + }, + "warmup_steps": { + "feasible_points": [6999] + }, + "weight_decay": { + "feasible_points": [0.08135402759553023] + } } diff --git a/reference_algorithms/target_setting_algorithms/imagenet_vit_map/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_vit_map/tuning_search_space.json index 98360bdff..6c7501295 100644 --- a/reference_algorithms/target_setting_algorithms/imagenet_vit_map/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/imagenet_vit_map/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.0008445074561975979 - ] - }, - "beta1": { - "feasible_points": [ - 0.8895758153482813 - ] - }, - "beta2": { - "feasible_points": [ - 0.9978504782314613 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.08135402759553023 - ] - } + "learning_rate": { + "feasible_points": [0.0008445074561975979] + }, + "beta1": { + "feasible_points": [0.8895758153482813] + }, + "beta2": { + "feasible_points": [0.9978504782314613] + }, + "warmup_steps": { + "feasible_points": [6999] + }, + "weight_decay": { + "feasible_points": [0.08135402759553023] + } } diff --git a/reference_algorithms/target_setting_algorithms/imagenet_vit_post_ln/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_vit_post_ln/tuning_search_space.json index d6f2053ff..94711417c 100644 --- a/reference_algorithms/target_setting_algorithms/imagenet_vit_post_ln/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/imagenet_vit_post_ln/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.00026032497966327757 - ] - }, - "beta1": { - "feasible_points": [ - 0.9709035036599892 - ] - }, - "beta2": { - "feasible_points": [ - 0.6572080806975734 - ] - }, - "warmup_steps": { - "feasible_points": [ - 13999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.03077045727617869 - ] - } + "learning_rate": { + "feasible_points": [0.00026032497966327757] + }, + "beta1": { + "feasible_points": [0.9709035036599892] + }, + "beta2": { + "feasible_points": [0.6572080806975734] + }, + "warmup_steps": { + "feasible_points": [13999] + }, + "weight_decay": { + "feasible_points": 
[0.03077045727617869] + } } diff --git a/reference_algorithms/target_setting_algorithms/jax_adamw.py b/reference_algorithms/target_setting_algorithms/jax_adamw.py index b64f0dfd6..3fa5e4955 100644 --- a/reference_algorithms/target_setting_algorithms/jax_adamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_adamw.py @@ -1,44 +1,54 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in Jax.""" -from flax import jax_utils + import jax import jax.numpy as jnp import optax +from flax import jax_utils from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import \ - update_params # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 + data_selection, +) +from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401 + get_batch_size, +) +from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401 + update_params, +) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates an AdamW optimizer and a learning rate schedule.""" del model_params del model_state del rng target_setting_step_hint = int(0.75 * workload.step_hint) - lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, - hyperparameters) + lr_schedule_fn = cosine_warmup.jax_cosine_warmup( + target_setting_step_hint, hyperparameters + ) # Create optimizer. 
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8 + ) opt_init_fn, opt_update_fn = optax.adamw( - learning_rate=lr_schedule_fn, - b1=hyperparameters.beta1, - b2=hyperparameters.beta2, - eps=epsilon, - weight_decay=hyperparameters.weight_decay) + learning_rate=lr_schedule_fn, + b1=hyperparameters.beta1, + b2=hyperparameters.beta2, + eps=epsilon, + weight_decay=hyperparameters.weight_decay, + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn diff --git a/reference_algorithms/target_setting_algorithms/jax_momentum.py b/reference_algorithms/target_setting_algorithms/jax_momentum.py index a6c3d853b..92d3f3a8f 100644 --- a/reference_algorithms/target_setting_algorithms/jax_momentum.py +++ b/reference_algorithms/target_setting_algorithms/jax_momentum.py @@ -2,25 +2,30 @@ from typing import Callable -from flax import jax_utils import jax import jax.numpy as jnp import optax +from flax import jax_utils from algoperf import spec -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import \ - update_params # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 + data_selection, +) +from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401 + get_batch_size, +) +from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401 + update_params, +) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Nesterov optimizer and a learning rate schedule.""" del model_params del model_state @@ -28,38 +33,44 @@ def init_optimizer_state(workload: spec.Workload, # Create learning rate schedule. target_setting_step_hint = int(0.75 * workload.step_hint) - lr_schedule_fn = create_lr_schedule_fn(target_setting_step_hint, - hyperparameters) + lr_schedule_fn = create_lr_schedule_fn( + target_setting_step_hint, hyperparameters + ) # Create optimizer. 
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) opt_init_fn, opt_update_fn = sgd( - learning_rate=lr_schedule_fn, - weight_decay=hyperparameters.weight_decay, - momentum=hyperparameters.beta1, - nesterov=False) + learning_rate=lr_schedule_fn, + weight_decay=hyperparameters.weight_decay, + momentum=hyperparameters.beta1, + nesterov=False, + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn def create_lr_schedule_fn( - step_hint: int, - hyperparameters: spec.Hyperparameters) -> Callable[[int], float]: + step_hint: int, hyperparameters: spec.Hyperparameters +) -> Callable[[int], float]: warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=hyperparameters.warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=hyperparameters.warmup_steps, + ) decay_steps = step_hint - hyperparameters.warmup_steps polynomial_schedule_fn = optax.polynomial_schedule( - init_value=hyperparameters.learning_rate, - end_value=hyperparameters.learning_rate * hyperparameters.end_factor, - power=1, - transition_steps=int(decay_steps * hyperparameters.decay_steps_factor)) + init_value=hyperparameters.learning_rate, + end_value=hyperparameters.learning_rate * hyperparameters.end_factor, + power=1, + transition_steps=int(decay_steps * hyperparameters.decay_steps_factor), + ) lr_schedule_fn = optax.join_schedules( - schedules=[warmup_fn, polynomial_schedule_fn], - boundaries=[hyperparameters.warmup_steps]) + schedules=[warmup_fn, polynomial_schedule_fn], + boundaries=[hyperparameters.warmup_steps], + ) return lr_schedule_fn @@ -85,6 +96,8 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): {SGD, Momentum, Nesterov} update. 
""" return optax.chain( - optax.add_decayed_weights(weight_decay), - optax.sgd( - learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) + optax.add_decayed_weights(weight_decay), + optax.sgd( + learning_rate=learning_rate, momentum=momentum, nesterov=nesterov + ), + ) diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 597a43c9e..6883b17ab 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -3,33 +3,35 @@ from typing import Any, Callable, NamedTuple, Optional, Union import chex -from flax import jax_utils import jax import jax.numpy as jnp import optax +from flax import jax_utils from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import \ - update_params # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 + data_selection, +) +from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401 + get_batch_size, +) +from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401 + update_params, +) # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None, ) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. References: @@ -61,19 +63,22 @@ def nadamw( An (init_fn, update_fn) tuple. """ return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate), + ) # All functions below are forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: +def scale_by_nadam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5, +) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. References: There seem to be multiple versions of NAdam. 
The original version is here @@ -109,7 +114,8 @@ def update_fn(updates, state, params=None): mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) updates = jax.tree.map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat + ) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -117,6 +123,7 @@ def update_fn(updates, state, params=None): class ScaleByAdamState(NamedTuple): """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. mu: optax.Updates nu: optax.Updates @@ -125,7 +132,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" return jax.tree.map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments + ) def _bias_correction(moment, decay, count): @@ -141,31 +149,37 @@ def scale_by_learning_rate(learning_rate, flip_sign=True): return optax.scale(m * learning_rate) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_params del model_state del rng target_setting_step_hint = int(0.75 * workload.step_hint) - lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, - hyperparameters) + lr_schedule_fn = cosine_warmup.jax_cosine_warmup( + target_setting_step_hint, hyperparameters + ) # Create optimizer. 
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8 + ) opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=hyperparameters.beta1, - b2=hyperparameters.beta2, - eps=epsilon, - weight_decay=hyperparameters.weight_decay) + learning_rate=lr_schedule_fn, + b1=hyperparameters.beta1, + b2=hyperparameters.beta2, + eps=epsilon, + weight_decay=hyperparameters.weight_decay, + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn diff --git a/reference_algorithms/target_setting_algorithms/jax_nesterov.py b/reference_algorithms/target_setting_algorithms/jax_nesterov.py index 0c11044fc..7f4d1cd86 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nesterov.py +++ b/reference_algorithms/target_setting_algorithms/jax_nesterov.py @@ -2,25 +2,30 @@ from typing import Callable -from flax import jax_utils import jax import jax.numpy as jnp import optax +from flax import jax_utils from algoperf import spec -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import \ - update_params # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 + data_selection, +) +from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401 + get_batch_size, +) +from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401 + update_params, +) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Nesterov optimizer and a learning rate schedule.""" del model_params del model_state @@ -28,38 +33,44 @@ def init_optimizer_state(workload: spec.Workload, # Create learning rate schedule. target_setting_step_hint = int(0.75 * workload.step_hint) - lr_schedule_fn = create_lr_schedule_fn(target_setting_step_hint, - hyperparameters) + lr_schedule_fn = create_lr_schedule_fn( + target_setting_step_hint, hyperparameters + ) # Create optimizer. 
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) opt_init_fn, opt_update_fn = sgd( - learning_rate=lr_schedule_fn, - weight_decay=hyperparameters.weight_decay, - momentum=hyperparameters.beta1, - nesterov=True) + learning_rate=lr_schedule_fn, + weight_decay=hyperparameters.weight_decay, + momentum=hyperparameters.beta1, + nesterov=True, + ) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn def create_lr_schedule_fn( - step_hint: int, - hyperparameters: spec.Hyperparameters) -> Callable[[int], float]: + step_hint: int, hyperparameters: spec.Hyperparameters +) -> Callable[[int], float]: warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=hyperparameters.warmup_steps) + init_value=0.0, + end_value=hyperparameters.learning_rate, + transition_steps=hyperparameters.warmup_steps, + ) decay_steps = step_hint - hyperparameters.warmup_steps polynomial_schedule_fn = optax.polynomial_schedule( - init_value=hyperparameters.learning_rate, - end_value=hyperparameters.learning_rate * hyperparameters.end_factor, - power=1, - transition_steps=int(decay_steps * hyperparameters.decay_steps_factor)) + init_value=hyperparameters.learning_rate, + end_value=hyperparameters.learning_rate * hyperparameters.end_factor, + power=1, + transition_steps=int(decay_steps * hyperparameters.decay_steps_factor), + ) lr_schedule_fn = optax.join_schedules( - schedules=[warmup_fn, polynomial_schedule_fn], - boundaries=[hyperparameters.warmup_steps]) + schedules=[warmup_fn, polynomial_schedule_fn], + boundaries=[hyperparameters.warmup_steps], + ) return lr_schedule_fn @@ -85,6 +96,8 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): {SGD, Momentum, Nesterov} update. 
""" return optax.chain( - optax.add_decayed_weights(weight_decay), - optax.sgd( - learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) + optax.add_decayed_weights(weight_decay), + optax.sgd( + learning_rate=learning_rate, momentum=momentum, nesterov=nesterov + ), + ) diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 217228935..d9b12f5ca 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -1,11 +1,12 @@ """Update submission function in Jax.""" + import functools from typing import Any, Dict, List, Optional, Tuple import jax -from jax import lax import jax.numpy as jnp import optax +from jax import lax from algoperf import spec @@ -13,75 +14,84 @@ @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4), +) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, +): def _loss_fn(params): """Loss function used for training.""" logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] return summed_loss, (n_valid_examples, new_model_state) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) + current_param_container + ) # Get correct global mean loss and grad. 
(summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' + ) loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)) + ) if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -98,32 +108,46 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - new_optimizer_state, new_params, new_model_state, loss, grad_norm = pmapped_train_step( # pylint: disable=line-too-long - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, batch, per_device_rngs, grad_clip, - label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = ( + pmapped_train_step( # pylint: disable=line-too-long + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) + ) # Log loss, grad_norm. 
- if ((global_step <= 100 or global_step % 500 == 0) and - workload.metrics_logger is not None): + if ( + global_step <= 100 or global_step % 500 == 0 + ) and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step, + ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json index 482a28931..936833cf3 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.002106913873888147 - ] - }, - "beta1": { - "feasible_points": [ - 0.8231189937738506 - ] - }, - "beta2": { - "feasible_points": [ - 0.8774571227688758 - ] - }, - "warmup_steps": { - "feasible_points": [ - 1199 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.27590534177690645 - ] - } + "learning_rate": { + "feasible_points": [0.002106913873888147] + }, + "beta1": { + "feasible_points": [0.8231189937738506] + }, + "beta2": { + "feasible_points": [0.8774571227688758] + }, + "warmup_steps": { + "feasible_points": [1199] + }, + "weight_decay": { + "feasible_points": [0.27590534177690645] + } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json index 22f3376b4..faefa750e 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.0007852999990476642 - ] - }, - "beta1": { - "feasible_points": [ - 0.6994142393023162 - ] - }, - "beta2": { - "feasible_points": [ - 0.9918636824608852 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6000 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.07286322158086678 - ] - } + "learning_rate": { + "feasible_points": [0.0007852999990476642] + }, + "beta1": { + "feasible_points": [0.6994142393023162] + }, + "beta2": { + "feasible_points": [0.9918636824608852] + }, + 
"warmup_steps": { + "feasible_points": [6000] + }, + "weight_decay": { + "feasible_points": [0.07286322158086678] + } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json index ad200c01b..16ab02525 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.000590120167916659 - ] - }, - "beta1": { - "feasible_points": [ - 0.737199286155609 - ] - }, - "beta2": { - "feasible_points": [ - 0.05919391544031072 - ] - }, - "warmup_steps": { - "feasible_points": [ - 9999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.14128519778326312 - ] - } + "learning_rate": { + "feasible_points": [0.000590120167916659] + }, + "beta1": { + "feasible_points": [0.737199286155609] + }, + "beta2": { + "feasible_points": [0.05919391544031072] + }, + "warmup_steps": { + "feasible_points": [9999] + }, + "weight_decay": { + "feasible_points": [0.14128519778326312] + } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json index 8297cf0ae..d596dcd2b 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.0014446807792420305 - ] - }, - "beta1": { - "feasible_points": [ - 0.7427148812902895 - ] - }, - "beta2": { - "feasible_points": [ - 0.8993064520764248 - ] - }, - "warmup_steps": { - "feasible_points": [ - 3000 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.06875136511682291 - ] - } + "learning_rate": { + "feasible_points": [0.0014446807792420305] + }, + "beta1": { + "feasible_points": [0.7427148812902895] + }, + "beta2": { + "feasible_points": [0.8993064520764248] + }, + "warmup_steps": { + "feasible_points": [3000] + }, + "weight_decay": { + "feasible_points": [0.06875136511682291] + } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_no_resnet/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_no_resnet/tuning_search_space.json index b31b711f7..dbcbecf78 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_no_resnet/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_no_resnet/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.0035278622506232458 - ] - }, - "beta1": { - "feasible_points": [ - 0.8192305396005781 - ] - }, - "beta2": { - "feasible_points": [ - 0.495850879212151 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6000 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.04339748256184769 - ] - } + "learning_rate": { + "feasible_points": [0.0035278622506232458] + }, + "beta1": { + "feasible_points": [0.8192305396005781] + }, + "beta2": { + "feasible_points": [0.495850879212151] + }, + "warmup_steps": { + 
"feasible_points": [6000] + }, + "weight_decay": { + "feasible_points": [0.04339748256184769] + } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json index e20a2dae1..bbea133cb 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.001308209823469072 - ] - }, - "beta1": { - "feasible_points": [ - 0.9731333693827139 - ] - }, - "beta2": { - "feasible_points": [ - 0.9981232922116359 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6000 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.16375311233774334 - ] - } + "learning_rate": { + "feasible_points": [0.001308209823469072] + }, + "beta1": { + "feasible_points": [0.9731333693827139] + }, + "beta2": { + "feasible_points": [0.9981232922116359] + }, + "warmup_steps": { + "feasible_points": [6000] + }, + "weight_decay": { + "feasible_points": [0.16375311233774334] + } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json index 0a9bfb3cf..52fe59d84 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.004958460849689891 - ] - }, - "beta1": { - "feasible_points": [ - 0.863744242567442 - ] - }, - "beta2": { - "feasible_points": [ - 0.6291854735396584 - ] - }, - "warmup_steps": { - "feasible_points": [ - 720 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.1147386261512052 - ] - } + "learning_rate": { + "feasible_points": [0.004958460849689891] + }, + "beta1": { + "feasible_points": [0.863744242567442] + }, + "beta2": { + "feasible_points": [0.6291854735396584] + }, + "warmup_steps": { + "feasible_points": [720] + }, + "weight_decay": { + "feasible_points": [0.1147386261512052] + } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json index e76a48325..898fc9e36 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.0020162740358935045 - ] - }, - "beta1": { - "feasible_points": [ - 0.9604907112078142 - ] - }, - "beta2": { - "feasible_points": [ - 0.8765457000160508 - ] - }, - "warmup_steps": { - "feasible_points": [ - 3600 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.0006149579248633481 - ] - } + "learning_rate": { + "feasible_points": [0.0020162740358935045] + }, + "beta1": { + "feasible_points": [0.9604907112078142] + }, + "beta2": { + "feasible_points": [0.8765457000160508] + }, + "warmup_steps": { + 
"feasible_points": [3600] + }, + "weight_decay": { + "feasible_points": [0.0006149579248633481] + } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json index 55f70f9fc..94f150ad3 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.0014446807792420305 - ] - }, - "beta1": { - "feasible_points": [ - 0.7427148812902895 - ] - }, - "beta2": { - "feasible_points": [ - 0.8993064520764248 - ] - }, - "warmup_steps": { - "feasible_points": [ - 1800 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.06875136511682291 - ] - } + "learning_rate": { + "feasible_points": [0.0014446807792420305] + }, + "beta1": { + "feasible_points": [0.7427148812902895] + }, + "beta2": { + "feasible_points": [0.8993064520764248] + }, + "warmup_steps": { + "feasible_points": [1800] + }, + "weight_decay": { + "feasible_points": [0.06875136511682291] + } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json index e5f906688..517e4a455 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.003604759885558324 - ] - }, - "beta1": { - "feasible_points": [ - 0.9931094324430452 - ] - }, - "beta2": { - "feasible_points": [ - 0.9976871843749077 - ] - }, - "warmup_steps": { - "feasible_points": [ - 720 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.120077307855989 - ] - } + "learning_rate": { + "feasible_points": [0.003604759885558324] + }, + "beta1": { + "feasible_points": [0.9931094324430452] + }, + "beta2": { + "feasible_points": [0.9976871843749077] + }, + "warmup_steps": { + "feasible_points": [720] + }, + "weight_decay": { + "feasible_points": [0.120077307855989] + } } diff --git a/reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json index 0f365a183..266f0f3f5 100644 --- a/reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json @@ -1,32 +1,20 @@ { - "learning_rate": { - "feasible_points": [ - 2.4917728606918423 - ] - }, - "beta1": { - "feasible_points": [ - 0.9449369031171744 - ] - }, - "warmup_steps": { - "feasible_points": [ - 3000 - ] - }, - "decay_steps_factor": { - "feasible_points": [ - 0.861509027839639 - ] - }, - "end_factor": { - "feasible_points": [ - 0.001 - ] - }, - "weight_decay": { - "feasible_points": [ - 1.2859640541025928e-7 - ] - } + "learning_rate": { + "feasible_points": [2.4917728606918423] + }, + "beta1": { + "feasible_points": [0.9449369031171744] + }, + "warmup_steps": { + "feasible_points": [3000] + }, + "decay_steps_factor": { + "feasible_points": [0.861509027839639] + }, + "end_factor": { + "feasible_points": [0.001] + }, + "weight_decay": { + 
"feasible_points": [1.2859640541025928e-7] + } } diff --git a/reference_algorithms/target_setting_algorithms/ogbg_gelu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/ogbg_gelu/tuning_search_space.json index 0749f96d6..b5b4aad30 100644 --- a/reference_algorithms/target_setting_algorithms/ogbg_gelu/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/ogbg_gelu/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.01897755400372091 - ] - }, - "beta1": { - "feasible_points": [ - 0.9666072782043229 - ] - }, - "beta2": { - "feasible_points": [ - 0.99681600289198 - ] - }, - "warmup_steps": { - "feasible_points": [ - 3000 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.015653883841116094 - ] - } -} \ No newline at end of file + "learning_rate": { + "feasible_points": [0.01897755400372091] + }, + "beta1": { + "feasible_points": [0.9666072782043229] + }, + "beta2": { + "feasible_points": [0.99681600289198] + }, + "warmup_steps": { + "feasible_points": [3000] + }, + "weight_decay": { + "feasible_points": [0.015653883841116094] + } +} diff --git a/reference_algorithms/target_setting_algorithms/ogbg_model_size/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/ogbg_model_size/tuning_search_space.json index d5af3b03e..b69114b88 100644 --- a/reference_algorithms/target_setting_algorithms/ogbg_model_size/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/ogbg_model_size/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.001734480757979605 - ] - }, - "beta1": { - "feasible_points": [ - 0.855609542347586 - ] - }, - "beta2": { - "feasible_points": [ - 0.9834185656478605 - ] - }, - "warmup_steps": { - "feasible_points": [ - 3000 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.019843063335529494 - ] - } -} \ No newline at end of file + "learning_rate": { + "feasible_points": [0.001734480757979605] + }, + "beta1": { + "feasible_points": [0.855609542347586] + }, + "beta2": { + "feasible_points": [0.9834185656478605] + }, + "warmup_steps": { + "feasible_points": [3000] + }, + "weight_decay": { + "feasible_points": [0.019843063335529494] + } +} diff --git a/reference_algorithms/target_setting_algorithms/ogbg_silu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/ogbg_silu/tuning_search_space.json index b9f83a5ed..e1512c02a 100644 --- a/reference_algorithms/target_setting_algorithms/ogbg_silu/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/ogbg_silu/tuning_search_space.json @@ -1,27 +1,17 @@ { - "learning_rate": { - "feasible_points": [ - 0.00027866530268792414 - ] - }, - "beta1": { - "feasible_points": [ - 0.9919340993463499 - ] - }, - "beta2": { - "feasible_points": [ - 0.9979843253162892 - ] - }, - "warmup_steps": { - "feasible_points": [ - 6000 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.00032418357325210813 - ] - } -} \ No newline at end of file + "learning_rate": { + "feasible_points": [0.00027866530268792414] + }, + "beta1": { + "feasible_points": [0.9919340993463499] + }, + "beta2": { + "feasible_points": [0.9979843253162892] + }, + "warmup_steps": { + "feasible_points": [6000] + }, + "weight_decay": { + "feasible_points": [0.00032418357325210813] + } +} diff --git a/reference_algorithms/target_setting_algorithms/pytorch_adamw.py b/reference_algorithms/target_setting_algorithms/pytorch_adamw.py index c87bdfb7d..14e8155d4 100644 --- 
a/reference_algorithms/target_setting_algorithms/pytorch_adamw.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_adamw.py @@ -4,37 +4,44 @@ from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ - update_params # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 + data_selection, +) +from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401 + get_batch_size, +) +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401 + update_params, +) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates an AdamW optimizer and a learning rate schedule.""" del model_state del rng epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8 + ) optimizer_state = { - 'optimizer': - torch.optim.AdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=epsilon, - weight_decay=hyperparameters.weight_decay), + 'optimizer': torch.optim.AdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(hyperparameters.beta1, hyperparameters.beta2), + eps=epsilon, + weight_decay=hyperparameters.weight_decay, + ), } target_setting_step_hint = int(0.75 * workload.step_hint) optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( - target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + target_setting_step_hint, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state diff --git a/reference_algorithms/target_setting_algorithms/pytorch_momentum.py b/reference_algorithms/target_setting_algorithms/pytorch_momentum.py index 584caff39..a23b835eb 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_momentum.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_momentum.py @@ -4,40 +4,47 @@ from torch.optim.lr_scheduler import LambdaLR from algoperf import spec -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_momentum import \ - create_lr_schedule_fn -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ - update_params # pylint: disable=unused-import - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) 
-> spec.OptimizerState: +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 + data_selection, +) +from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401 + get_batch_size, +) +from reference_algorithms.target_setting_algorithms.jax_momentum import ( + create_lr_schedule_fn, +) +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401 + update_params, +) + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Nesterov optimizer and a learning rate schedule.""" del model_state del rng # Create optimizer. optimizer_state = { - 'optimizer': - torch.optim.SGD( - model_params.parameters(), - lr=hyperparameters.learning_rate, - momentum=hyperparameters.beta1, - weight_decay=hyperparameters.weight_decay, - nesterov=False), + 'optimizer': torch.optim.SGD( + model_params.parameters(), + lr=hyperparameters.learning_rate, + momentum=hyperparameters.beta1, + weight_decay=hyperparameters.weight_decay, + nesterov=False, + ), } # Create learning rate schedule. target_setting_step_hint = int(0.75 * workload.step_hint) - lr_schedule_fn = create_lr_schedule_fn(target_setting_step_hint, - hyperparameters) + lr_schedule_fn = create_lr_schedule_fn( + target_setting_step_hint, hyperparameters + ) # PyTorch's LambdaLR expects the lr_lambda fn to return a factor which will # be multiplied with the base lr, so we have to divide by it here. @@ -45,6 +52,7 @@ def _lr_lambda(step: int) -> float: return lr_schedule_fn(step).item() / hyperparameters.learning_rate optimizer_state['scheduler'] = LambdaLR( - optimizer_state['optimizer'], lr_lambda=_lr_lambda) + optimizer_state['optimizer'], lr_lambda=_lr_lambda + ) return optimizer_state diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py index a9dee1d79..d301a233f 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py @@ -8,43 +8,43 @@ from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ - update_params # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 + data_selection, +) +from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401 + get_batch_size, +) +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401 + update_params, +) # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. 
- Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): if not 0.0 <= lr: raise ValueError(f'Invalid learning rate: {lr}') if not 0.0 <= eps: @@ -56,7 +56,10 @@ def __init__(self, if not 0.0 <= weight_decay: raise ValueError(f'Invalid weight_decay value: {weight_decay}') defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + 'lr': lr, + 'betas': betas, + 'eps': eps, + 'weight_decay': weight_decay, } super().__init__(params, defaults) @@ -64,7 +67,8 @@ def __setstate__(self, state): super().__setstate__(state) state_values = list(self.state.values()) step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) + state_values[0]['step'] + ) if not step_is_tensor: for s in state_values: s['step'] = torch.tensor(float(s['step'])) @@ -72,10 +76,10 @@ def __setstate__(self, state): @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ self._cuda_graph_capture_health_check() loss = None @@ -103,51 +107,57 @@ def step(self, closure=None): # State initialization if len(state) == 0: - state['step'] = torch.tensor(0.) 
+ state['step'] = torch.tensor(0.0) # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + p, memory_format=torch.preserve_format + ) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) state_steps.append(state['step']) nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + ) return loss -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float): +def nadamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, +): r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. - """ + See NAdamW class for details. + """ if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors' + ) for i, param in enumerate(params): grad = grads[i] @@ -185,28 +195,32 @@ def nadamw(params: List[Tensor], exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" del model_state del rng epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8 + ) optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=epsilon, - weight_decay=hyperparameters.weight_decay), + 'optimizer': NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(hyperparameters.beta1, hyperparameters.beta2), + eps=epsilon, + weight_decay=hyperparameters.weight_decay, + ), } target_setting_step_hint = int(0.75 * workload.step_hint) optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( - target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + target_setting_step_hint, hyperparameters, optimizer_state['optimizer'] + ) return optimizer_state diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py b/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py index 8e10db4ef..3a6294a28 100644 --- 
a/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py @@ -4,40 +4,47 @@ from torch.optim.lr_scheduler import LambdaLR from algoperf import spec -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_nesterov import \ - create_lr_schedule_fn -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ - update_params # pylint: disable=unused-import - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 + data_selection, +) +from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401 + get_batch_size, +) +from reference_algorithms.target_setting_algorithms.jax_momentum import ( + create_lr_schedule_fn, +) +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401 + update_params, +) + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Nesterov optimizer and a learning rate schedule.""" del model_state del rng # Create optimizer. optimizer_state = { - 'optimizer': - torch.optim.SGD( - model_params.parameters(), - lr=hyperparameters.learning_rate, - momentum=hyperparameters.beta1, - weight_decay=hyperparameters.weight_decay, - nesterov=True), + 'optimizer': torch.optim.SGD( + model_params.parameters(), + lr=hyperparameters.learning_rate, + momentum=hyperparameters.beta1, + weight_decay=hyperparameters.weight_decay, + nesterov=True, + ), } # Create learning rate schedule. target_setting_step_hint = int(0.75 * workload.step_hint) - lr_schedule_fn = create_lr_schedule_fn(target_setting_step_hint, - hyperparameters) + lr_schedule_fn = create_lr_schedule_fn( + target_setting_step_hint, hyperparameters + ) # PyTorch's LambdaLR expects the lr_lambda fn to return a factor which will # be multiplied with the base lr, so we have to divide by it here. 
@@ -45,6 +52,7 @@ def _lr_lambda(step: int) -> float: return lr_schedule_fn(step).item() / hyperparameters.learning_rate optimizer_state['scheduler'] = LambdaLR( - optimizer_state['optimizer'], lr_lambda=_lr_lambda) + optimizer_state['optimizer'], lr_lambda=_lr_lambda + ) return optimizer_state diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index bbfd8b0f2..0bef4548f 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -2,9 +2,9 @@ from typing import Any, Dict, List, Optional, Tuple -from absl import logging import torch import torch.distributed.nn as dist_nn +from absl import logging from algoperf import spec from algoperf.pytorch_utils import pytorch_setup @@ -13,18 +13,19 @@ def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type @@ -36,26 +37,30 @@ def update_params( optimizer_state['optimizer'].zero_grad() logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + hyperparameters.label_smoothing + if hasattr(hyperparameters, 'label_smoothing') + else 0.0 + ) if hasattr(hyperparameters, 'grad_clip'): grad_clip = hyperparameters.grad_clip else: grad_clip = None loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) summed_loss = loss_dict['summed'] n_valid_examples = loss_dict['n_valid_examples'] if USE_PYTORCH_DDP: @@ -68,7 +73,8 @@ def update_params( if grad_clip is not None: torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) + current_model.parameters(), max_norm=grad_clip + ) optimizer_state['optimizer'].step() if 'scheduler' in optimizer_state: optimizer_state['scheduler'].step() @@ -78,32 +84,39 @@ def update_params( with torch.no_grad(): 
parameters = [p for p in current_model.parameters() if p.grad is not None] grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) if workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) return (optimizer_state, current_param_container, new_model_state) -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del workload del hyperparameters diff --git a/reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json index f0ef45daa..aee18d976 100644 --- a/reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json @@ -1,32 +1,20 @@ { - "learning_rate": { - "feasible_points": [ - 0.0017486387539278373 - ] - }, - "beta1": { - "feasible_points": [ - 0.9326607383586145 - ] - }, - "beta2": { - "feasible_points": [ - 0.9955159689799007 - ] - }, - "warmup_steps": { - "feasible_points": [ - 1999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.08121616522670176 - ] - }, - "label_smoothing": { - "feasible_points": [ - 0.0 - ] - } + "learning_rate": { + "feasible_points": [0.0017486387539278373] + }, + "beta1": { + "feasible_points": [0.9326607383586145] + }, + "beta2": { + "feasible_points": [0.9955159689799007] + }, + "warmup_steps": { + "feasible_points": [1999] + }, + "weight_decay": { + "feasible_points": [0.08121616522670176] + }, + "label_smoothing": { + "feasible_points": [0.0] + } } diff --git a/reference_algorithms/target_setting_algorithms/wmt_attention_temp/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/wmt_attention_temp/tuning_search_space.json index 266cdedbb..e1ce2229f 100644 --- a/reference_algorithms/target_setting_algorithms/wmt_attention_temp/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/wmt_attention_temp/tuning_search_space.json @@ -1,32 +1,20 @@ { - "learning_rate": { - "feasible_points": [ - 0.000590120167916659 - ] - }, - "beta1": { - "feasible_points": [ - 0.737199286155609 - ] - }, - "beta2": { - "feasible_points": [ - 0.05919391544031072 - ] - }, - "warmup_steps": { - "feasible_points": [ - 9999 - ] - }, - 
"weight_decay": { - "feasible_points": [ - 0.14128519778326312 - ] - }, - "label_smoothing": { - "feasible_points": [ - 0.0 - ] - } + "learning_rate": { + "feasible_points": [0.000590120167916659] + }, + "beta1": { + "feasible_points": [0.737199286155609] + }, + "beta2": { + "feasible_points": [0.05919391544031072] + }, + "warmup_steps": { + "feasible_points": [9999] + }, + "weight_decay": { + "feasible_points": [0.14128519778326312] + }, + "label_smoothing": { + "feasible_points": [0.0] + } } diff --git a/reference_algorithms/target_setting_algorithms/wmt_glu_tanh/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/wmt_glu_tanh/tuning_search_space.json index d288d9a49..0ed0f832a 100644 --- a/reference_algorithms/target_setting_algorithms/wmt_glu_tanh/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/wmt_glu_tanh/tuning_search_space.json @@ -1,32 +1,20 @@ { - "learning_rate": { - "feasible_points": [ - 0.000872041489644454 - ] - }, - "beta1": { - "feasible_points": [ - 0.45562164405092065 - ] - }, - "beta2": { - "feasible_points": [ - 0.9982167124443476 - ] - }, - "warmup_steps": { - "feasible_points": [ - 4999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.01536114562763022 - ] - }, - "label_smoothing": { - "feasible_points": [ - 0.1 - ] - } + "learning_rate": { + "feasible_points": [0.000872041489644454] + }, + "beta1": { + "feasible_points": [0.45562164405092065] + }, + "beta2": { + "feasible_points": [0.9982167124443476] + }, + "warmup_steps": { + "feasible_points": [4999] + }, + "weight_decay": { + "feasible_points": [0.01536114562763022] + }, + "label_smoothing": { + "feasible_points": [0.1] + } } diff --git a/reference_algorithms/target_setting_algorithms/wmt_post_ln/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/wmt_post_ln/tuning_search_space.json index 1327bcb38..d1b3045c7 100644 --- a/reference_algorithms/target_setting_algorithms/wmt_post_ln/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/wmt_post_ln/tuning_search_space.json @@ -1,32 +1,20 @@ { - "learning_rate": { - "feasible_points": [ - 0.0003477912008450351 - ] - }, - "beta1": { - "feasible_points": [ - 0.9936632117510711 - ] - }, - "beta2": { - "feasible_points": [ - 0.9967873550453692 - ] - }, - "warmup_steps": { - "feasible_points": [ - 9999 - ] - }, - "weight_decay": { - "feasible_points": [ - 0.04120183162940475 - ] - }, - "label_smoothing": { - "feasible_points": [ - 0.0 - ] - } + "learning_rate": { + "feasible_points": [0.0003477912008450351] + }, + "beta1": { + "feasible_points": [0.9936632117510711] + }, + "beta2": { + "feasible_points": [0.9967873550453692] + }, + "warmup_steps": { + "feasible_points": [9999] + }, + "weight_decay": { + "feasible_points": [0.04120183162940475] + }, + "label_smoothing": { + "feasible_points": [0.0] + } } diff --git a/scoring/algoperf_v05/generate_held_out_workloads.py b/scoring/algoperf_v05/generate_held_out_workloads.py index 647dc3c3d..e9ebf6a53 100644 --- a/scoring/algoperf_v05/generate_held_out_workloads.py +++ b/scoring/algoperf_v05/generate_held_out_workloads.py @@ -2,49 +2,51 @@ import os import struct -from absl import app -from absl import flags -from absl import logging import numpy as np +from absl import app, flags, logging flags.DEFINE_integer( - 'held_out_workloads_seed', - None, - 'Random seed for scoring.' 
- 'AlgoPerf v0.5 seed: 3438810845') -flags.DEFINE_string('output_filename', - 'held_out_workloads.json', - 'Path to file to record sampled held_out workloads.') + 'held_out_workloads_seed', + None, + 'Random seed for scoring.AlgoPerf v0.5 seed: 3438810845', +) +flags.DEFINE_string( + 'output_filename', + 'held_out_workloads.json', + 'Path to file to record sampled held_out workloads.', +) FLAGS = flags.FLAGS HELD_OUT_WORKLOADS = { - 'librispeech': [ - 'librispeech_conformer_attention_temperature', - 'librispeech_conformer_layernorm', - # 'librispeech_conformer_gelu', # Removed due to bug in target setting procedure - 'librispeech_deepspeech_no_resnet', - 'librispeech_deepspeech_norm_and_spec_aug', - 'librispeech_deepspeech_tanh' - ], - 'imagenet': [ - 'imagenet_resnet_silu', - 'imagenet_resnet_gelu', - 'imagenet_resnet_large_bn_init', - 'imagenet_vit_glu', - 'imagenet_vit_post_ln', - 'imagenet_vit_map' - ], - 'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'], - 'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'], - 'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'], - 'criteo1tb': [ - 'criteo1tb_layernorm', 'criteo1tb_embed_init', 'criteo1tb_resnet' - ] + 'librispeech': [ + 'librispeech_conformer_attention_temperature', + 'librispeech_conformer_layernorm', + # 'librispeech_conformer_gelu', # Removed due to bug in target setting procedure + 'librispeech_deepspeech_no_resnet', + 'librispeech_deepspeech_norm_and_spec_aug', + 'librispeech_deepspeech_tanh', + ], + 'imagenet': [ + 'imagenet_resnet_silu', + 'imagenet_resnet_gelu', + 'imagenet_resnet_large_bn_init', + 'imagenet_vit_glu', + 'imagenet_vit_post_ln', + 'imagenet_vit_map', + ], + 'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'], + 'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'], + 'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'], + 'criteo1tb': [ + 'criteo1tb_layernorm', + 'criteo1tb_embed_init', + 'criteo1tb_resnet', + ], } def save_held_out_workloads(held_out_workloads, filename): - with open(filename, "w") as f: + with open(filename, 'w') as f: json.dump(held_out_workloads, f) @@ -63,7 +65,7 @@ def main(_): sampled_index = rng.integers(len(v)) sampled_held_out_workloads.append(v[sampled_index]) - logging.info(f"Sampled held-out workloads: {sampled_held_out_workloads}") + logging.info(f'Sampled held-out workloads: {sampled_held_out_workloads}') save_held_out_workloads(sampled_held_out_workloads, output_filename) diff --git a/scoring/algoperf_v05/score_submissions.py b/scoring/algoperf_v05/score_submissions.py index 8cc06b15f..6ef931a54 100644 --- a/scoring/algoperf_v05/score_submissions.py +++ b/scoring/algoperf_v05/score_submissions.py @@ -16,57 +16,62 @@ import os import pickle -from absl import app -from absl import flags -from absl import logging import numpy as np import pandas as pd import performance_profile import scoring_utils +from absl import app, flags, logging from tabulate import tabulate flags.DEFINE_string( - 'submission_directory', - None, - 'Path to submission directory containing experiment directories.') + 'submission_directory', + None, + 'Path to submission directory containing experiment directories.', +) flags.DEFINE_string( - 'output_dir', - 'scoring_results', - 'Path to save performance profile artifacts, submission_summaries and results files.' 
+ 'output_dir', + 'scoring_results', + 'Path to save performance profile artifacts, submission_summaries and results files.', +) +flags.DEFINE_boolean( + 'compute_performance_profiles', + False, + 'Whether or not to compute the performance profiles.', ) -flags.DEFINE_boolean('compute_performance_profiles', - False, - 'Whether or not to compute the performance profiles.') flags.DEFINE_boolean( - 'strict', - False, - 'Whether to enforce scoring criteria on variant performance and on' - '5-trial median performance. Note that during official scoring this ' - 'flag will be set to True.') + 'strict', + False, + 'Whether to enforce scoring criteria on variant performance and on' + '5-trial median performance. Note that during official scoring this ' + 'flag will be set to True.', +) flags.DEFINE_boolean( - 'self_tuning_ruleset', - False, - 'Whether to score on self-tuning ruleset or externally tuned ruleset') + 'self_tuning_ruleset', + False, + 'Whether to score on self-tuning ruleset or externally tuned ruleset', +) flags.DEFINE_string( - 'save_results_to_filename', - None, - 'Filename to save the processed results that are fed into the performance profile functions.' + 'save_results_to_filename', + None, + 'Filename to save the processed results that are fed into the performance profile functions.', ) flags.DEFINE_string( - 'load_results_from_filename', - None, - 'Filename to load processed results from that are fed into performance profile functions' + 'load_results_from_filename', + None, + 'Filename to load processed results from that are fed into performance profile functions', ) flags.DEFINE_string( - 'exclude_submissions', - '', - 'Optional comma seperated list of names of submissions to exclude from scoring.' + 'exclude_submissions', + '', + 'Optional comma seperated list of names of submissions to exclude from scoring.', ) FLAGS = flags.FLAGS def get_summary_df(workload, workload_df, include_test_split=False): - validation_metric, validation_target = scoring_utils.get_workload_metrics_and_targets(workload, split='validation') + validation_metric, validation_target = ( + scoring_utils.get_workload_metrics_and_targets(workload, split='validation') + ) is_minimized = performance_profile.check_if_minimized(validation_metric) target_op = operator.le if is_minimized else operator.ge @@ -79,47 +84,69 @@ def get_summary_df(workload, workload_df, include_test_split=False): summary_df['val target metric name'] = validation_metric summary_df['val target metric value'] = validation_target - summary_df['val target reached'] = workload_df[validation_metric].apply( - lambda x: target_op(x, validation_target)).apply(np.any) + summary_df['val target reached'] = ( + workload_df[validation_metric] + .apply(lambda x: target_op(x, validation_target)) + .apply(np.any) + ) summary_df['best metric value on val'] = workload_df[validation_metric].apply( - lambda x: best_op(x)) + lambda x: best_op(x) + ) workload_df['index best eval on val'] = workload_df[validation_metric].apply( - lambda x: idx_op(x)) + lambda x: idx_op(x) + ) summary_df['time to best eval on val (s)'] = workload_df.apply( - lambda x: x['accumulated_submission_time'][x['index best eval on val']], - axis=1) - workload_df['val target reached'] = workload_df[validation_metric].apply( - lambda x: target_op(x, validation_target)).apply(np.any) + lambda x: x['accumulated_submission_time'][x['index best eval on val']], + axis=1, + ) + workload_df['val target reached'] = ( + workload_df[validation_metric] + .apply(lambda x: target_op(x, 
validation_target)) + .apply(np.any) + ) workload_df['index to target on val'] = workload_df.apply( - lambda x: np.argmax(target_op(x[validation_metric], validation_target)) - if x['val target reached'] else np.nan, - axis=1) + lambda x: np.argmax(target_op(x[validation_metric], validation_target)) + if x['val target reached'] + else np.nan, + axis=1, + ) summary_df['time to target on val (s)'] = workload_df.apply( - lambda x: x['accumulated_submission_time'][int(x[ - 'index to target on val'])] if x['val target reached'] else np.inf, - axis=1) + lambda x: x['accumulated_submission_time'][int(x['index to target on val'])] + if x['val target reached'] + else np.inf, + axis=1, + ) # test metrics if include_test_split: - test_metric, test_target = scoring_utils.get_workload_metrics_and_targets(workload, split='test') + test_metric, test_target = scoring_utils.get_workload_metrics_and_targets( + workload, split='test' + ) summary_df['test target metric name'] = test_metric summary_df['test target metric value'] = test_target - summary_df['test target reached'] = workload_df[test_metric].apply( - lambda x: target_op(x, test_target)).apply(np.any) + summary_df['test target reached'] = ( + workload_df[test_metric] + .apply(lambda x: target_op(x, test_target)) + .apply(np.any) + ) summary_df['best metric value on test'] = workload_df[test_metric].apply( - lambda x: best_op(x)) + lambda x: best_op(x) + ) workload_df['index best eval on test'] = workload_df[test_metric].apply( - lambda x: idx_op(x)) + lambda x: idx_op(x) + ) summary_df['time to best eval on test (s)'] = workload_df.apply( - lambda x: x['accumulated_submission_time'][x['index best eval on test'] - ], - axis=1) + lambda x: x['accumulated_submission_time'][x['index best eval on test']], + axis=1, + ) summary_df['time to target on test (s)'] = summary_df.apply( - lambda x: x['time to best eval on test (s)'] - if x['test target reached'] else np.inf, - axis=1) + lambda x: x['time to best eval on test (s)'] + if x['test target reached'] + else np.inf, + axis=1, + ) return summary_df @@ -133,7 +160,8 @@ def get_submission_summary(df, include_test_split=True): print(df) for workload, group in df.groupby('workload'): summary_df = get_summary_df( - workload, group, include_test_split=include_test_split) + workload, group, include_test_split=include_test_split + ) dfs.append(summary_df) df = pd.concat(dfs) @@ -164,61 +192,64 @@ def main(_): # Optionally read results to filename if FLAGS.load_results_from_filename: with open( - os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename), - 'rb') as f: + os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename), 'rb' + ) as f: results = pickle.load(f) else: for team in os.listdir(FLAGS.submission_directory): for submission in os.listdir( - os.path.join(FLAGS.submission_directory, team)): + os.path.join(FLAGS.submission_directory, team) + ): print(submission) if submission in FLAGS.exclude_submissions.split(','): continue - experiment_path = os.path.join(FLAGS.submission_directory, - team, - submission) + experiment_path = os.path.join( + FLAGS.submission_directory, team, submission + ) df = scoring_utils.get_experiment_df(experiment_path) results[submission] = df summary_df = get_submission_summary(df) with open( - os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), - 'w') as fout: + os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), 'w' + ) as fout: summary_df.to_csv(fout) # Optionally save results to filename if FLAGS.save_results_to_filename: with open( - 
os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename), - 'wb') as f: + os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename), 'wb' + ) as f: pickle.dump(results, f) if not FLAGS.strict: logging.warning( - 'You are running with strict=False. This will relax ' - 'scoring criteria on the held-out workloads, number of trials and number ' - 'of studies. Your score may not be an accurate representation ' - 'under competition scoring rules. To enforce the criteria set strict=True.' + 'You are running with strict=False. This will relax ' + 'scoring criteria on the held-out workloads, number of trials and number ' + 'of studies. Your score may not be an accurate representation ' + 'under competition scoring rules. To enforce the criteria set strict=True.' ) if FLAGS.compute_performance_profiles: performance_profile_df = performance_profile.compute_performance_profiles( - results, - time_col='score', - min_tau=1.0, - max_tau=4.0, - reference_submission_tag=None, - num_points=100, - scale='linear', - verbosity=0, - self_tuning_ruleset=FLAGS.self_tuning_ruleset, - strict=FLAGS.strict, - output_dir=FLAGS.output_dir, + results, + time_col='score', + min_tau=1.0, + max_tau=4.0, + reference_submission_tag=None, + num_points=100, + scale='linear', + verbosity=0, + self_tuning_ruleset=FLAGS.self_tuning_ruleset, + strict=FLAGS.strict, + output_dir=FLAGS.output_dir, ) if not os.path.exists(FLAGS.output_dir): os.mkdir(FLAGS.output_dir) performance_profile.plot_performance_profiles( - performance_profile_df, 'score', save_dir=FLAGS.output_dir) + performance_profile_df, 'score', save_dir=FLAGS.output_dir + ) performance_profile_str = tabulate( - performance_profile_df.T, headers='keys', tablefmt='psql') + performance_profile_df.T, headers='keys', tablefmt='psql' + ) logging.info(f'Performance profile:\n {performance_profile_str}') scores = compute_leaderboard_score(performance_profile_df) scores.to_csv(os.path.join(FLAGS.output_dir, 'scores.csv')) diff --git a/scoring/compute_speedups.py b/scoring/compute_speedups.py index d0e5bf70b..1740a6dce 100644 --- a/scoring/compute_speedups.py +++ b/scoring/compute_speedups.py @@ -2,39 +2,39 @@ import pickle -from absl import app -from absl import flags import numpy as np import pandas as pd -from performance_profile import BASE_WORKLOADS -from performance_profile import get_workloads_time_to_target +from absl import app, flags +from performance_profile import BASE_WORKLOADS, get_workloads_time_to_target from scipy import stats flags.DEFINE_string('results_txt', None, 'Path to full scoring results file.') flags.DEFINE_string( - 'base', - 'prize_qualification_baseline', - 'Base submission to compare to. Defaults to the `prize_qualification_baseline`.' + 'base', + 'prize_qualification_baseline', + 'Base submission to compare to. Defaults to the `prize_qualification_baseline`.', ) flags.DEFINE_string('comparison', None, 'Submission to compute the speedup of.') -flags.DEFINE_boolean('self_tuning_ruleset', - False, - 'Whether the self-tuning ruleset is being scored.') -flags.DEFINE_boolean('save_results', - False, - 'Whether to save the results to disk.') +flags.DEFINE_boolean( + 'self_tuning_ruleset', + False, + 'Whether the self-tuning ruleset is being scored.', +) +flags.DEFINE_boolean( + 'save_results', False, 'Whether to save the results to disk.' +) FLAGS = flags.FLAGS # These are the old budgets, used in the first iteration of the competition. 
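As context for the `compute_speedups.py` changes in this diff: the script's core calculation is a geometric mean over per-workload ratios of time-to-target between two submissions. A minimal, self-contained sketch of that calculation, using made-up numbers and illustrative variable names rather than the script's actual interface:

```python
import pandas as pd
from scipy import stats

# Hypothetical per-workload times-to-target (seconds); in the real script these
# come from performance_profile.get_workloads_time_to_target().
base = pd.Series({'fastmri': 4000.0, 'ogbg': 9000.0, 'wmt': 30000.0}, name='base')
comparison = pd.Series({'fastmri': 3000.0, 'ogbg': 7200.0, 'wmt': 24000.0}, name='comparison')

merged = pd.concat([base, comparison], axis=1)
# A ratio below 1 means the comparison submission reached the target faster.
merged['speedup'] = merged['comparison'] / merged['base']
mean_speedup = stats.gmean(merged['speedup'].to_numpy())  # geometric mean over workloads
print(f'Geometric-mean ratio: {mean_speedup:.3f} (~{1 - mean_speedup:.1%} faster)')
```

This mirrors the `merged_results['speedup']` and `stats.gmean(speedups)` steps performed by the reformatted `compute_speedup()` function.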
MAX_BUDGETS = { - 'criteo1tb': 7703, - 'fastmri': 8859, - 'imagenet_resnet': 63_008, - 'imagenet_vit': 77_520, - 'librispeech_conformer': 61_068, - 'librispeech_deepspeech': 55_506, - 'ogbg': 18_477, - 'wmt': 48_151, + 'criteo1tb': 7703, + 'fastmri': 8859, + 'imagenet_resnet': 63_008, + 'imagenet_vit': 77_520, + 'librispeech_conformer': 61_068, + 'librispeech_deepspeech': 55_506, + 'ogbg': 18_477, + 'wmt': 48_151, } @@ -63,16 +63,16 @@ def compute_speedup(): # Compute median over runtimes for both training algorithms base_results = get_workloads_time_to_target( - results[FLAGS.base], - FLAGS.base, - time_col="score", - self_tuning_ruleset=FLAGS.self_tuning_ruleset, + results[FLAGS.base], + FLAGS.base, + time_col='score', + self_tuning_ruleset=FLAGS.self_tuning_ruleset, ) comparison_results = get_workloads_time_to_target( - results[FLAGS.comparison], - FLAGS.comparison, - time_col="score", - self_tuning_ruleset=FLAGS.self_tuning_ruleset, + results[FLAGS.comparison], + FLAGS.comparison, + time_col='score', + self_tuning_ruleset=FLAGS.self_tuning_ruleset, ) # Merge results @@ -85,20 +85,23 @@ def compute_speedup(): merged_results = merged_results.apply(replace_inf, axis=1) # Compute speedup - merged_results['speedup'] = merged_results[ - f'{FLAGS.comparison}'] / merged_results[f'{FLAGS.base}'] + merged_results['speedup'] = ( + merged_results[f'{FLAGS.comparison}'] / merged_results[f'{FLAGS.base}'] + ) speedups = merged_results['speedup'].to_numpy() mean_speedup = stats.gmean(speedups) # Geometric mean over workload speedups print(merged_results, end='\n\n') print( - f"Average speedup of {FLAGS.comparison} compared to {FLAGS.base}: {mean_speedup} or roughly {(1-mean_speedup):.1%}" + f'Average speedup of {FLAGS.comparison} compared to {FLAGS.base}: {mean_speedup} or roughly {(1 - mean_speedup):.1%}' ) if FLAGS.save_results: # Optionally save results to disk - print("Saving results to disk...") - filename = f'{FLAGS.comparison}_vs_{FLAGS.base}_speedup_{(1-mean_speedup):.1%}.csv' + print('Saving results to disk...') + filename = ( + f'{FLAGS.comparison}_vs_{FLAGS.base}_speedup_{(1 - mean_speedup):.1%}.csv' + ) merged_results.to_csv(filename) diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 05026a0c7..4f2ae9c57 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -25,21 +25,22 @@ The keys in this dictionary should match the workload identifiers used in the dictionary of submissions. 
""" + import itertools import json import operator import os import re -from absl import logging import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np import pandas as pd +from absl import logging from tabulate import tabulate -from algoperf.workloads.workloads import get_base_workload_name import algoperf.workloads.workloads as workloads_registry +from algoperf.workloads.workloads import get_base_workload_name from scoring import scoring_utils WORKLOADS = workloads_registry.WORKLOADS @@ -49,9 +50,9 @@ # Open json file to read heldout workloads # TODO: This probably shouldn't be hardcoded but passed as an argument.\ try: - with open("held_out_workloads_algoperf_v05.json", "r") as f: + with open('held_out_workloads_algoperf_v05.json', 'r') as f: HELDOUT_WORKLOADS = json.load(f) -except: +except FileNotFoundError: HELDOUT_WORKLOADS = None # These global variables have to be set according to the current set of @@ -64,22 +65,22 @@ NUM_STUDIES = 3 MIN_EVAL_METRICS = [ - 'ce_loss', - 'error_rate', - 'ctc_loss', - 'wer', - 'l1_loss', - 'loss', + 'ce_loss', + 'error_rate', + 'ctc_loss', + 'wer', + 'l1_loss', + 'loss', ] MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu'] -#MPL params +# MPL params mpl.rcParams['figure.figsize'] = (16, 10) # Width, height in inches mpl.rcParams['font.family'] = 'serif' -mpl.rcParams['font.serif'] = [ - 'Times New Roman' -] + mpl.rcParams['font.serif'] # Add Times New Roman as first choice +mpl.rcParams['font.serif'] = ['Times New Roman'] + mpl.rcParams[ + 'font.serif' +] # Add Times New Roman as first choice mpl.rcParams['font.size'] = 22 mpl.rcParams['savefig.dpi'] = 300 # Set resolution for saved figures @@ -87,16 +88,17 @@ mpl.rcParams['lines.linewidth'] = 3 # Adjust line thickness if needed mpl.rcParams['lines.markersize'] = 6 # Adjust marker size if needed mpl.rcParams['axes.prop_cycle'] = mpl.cycler( - color=["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", - "#9467bd"]) # Example color cycle (consider ColorBrewer or viridis) + color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd'] +) # Example color cycle (consider ColorBrewer or viridis) mpl.rcParams['axes.labelsize'] = 22 # Axis label font size mpl.rcParams['xtick.labelsize'] = 20 # Tick label font size mpl.rcParams['ytick.labelsize'] = 20 # Legends and Gridlines mpl.rcParams['legend.fontsize'] = 20 # Legend font size -mpl.rcParams[ - 'legend.loc'] = 'best' # Let matplotlib decide the best legend location +mpl.rcParams['legend.loc'] = ( + 'best' # Let matplotlib decide the best legend location +) mpl.rcParams['axes.grid'] = True # Enable grid mpl.rcParams['grid.alpha'] = 0.4 # Gridline transparency @@ -113,7 +115,8 @@ def generate_eval_cols(metrics): MINIMIZE_REGISTRY = {k: True for k in generate_eval_cols(MIN_EVAL_METRICS)} MINIMIZE_REGISTRY.update( - {k: False for k in generate_eval_cols(MAX_EVAL_METRICS)}) + {k: False for k in generate_eval_cols(MAX_EVAL_METRICS)} +) MINIMIZE_REGISTRY['train_cost'] = True @@ -125,13 +128,15 @@ def check_if_minimized(col_name): if col in col_name: return MINIMIZE_REGISTRY[col] - raise ValueError(f'Column {col_name} not found in `MINIMIZE_REGISTRY` as ' - 'either a column name or a substring of a column name.') + raise ValueError( + f'Column {col_name} not found in `MINIMIZE_REGISTRY` as ' + 'either a column name or a substring of a column name.' 
+ ) -def get_best_trial_index(workload_df, - validation_metric, - validation_target=None): +def get_best_trial_index( + workload_df, validation_metric, validation_target=None +): """Get the eval index in which a workload reaches the target metric_col. Args: @@ -150,7 +155,8 @@ def get_best_trial_index(workload_df, op = operator.le if is_minimized else operator.ge validation_target_reached = validation_series.apply( - lambda x: op(x, validation_target)) + lambda x: op(x, validation_target) + ) target_reached = pd.Series(validation_target_reached) # Remove trials that never reach the target @@ -166,12 +172,14 @@ def get_best_trial_index(workload_df, return trial, index_reached[trial] -def get_workloads_time_to_target(submission, - submission_name, - time_col='global_step', - verbosity=1, - self_tuning_ruleset=False, - strict=False): +def get_workloads_time_to_target( + submission, + submission_name, + time_col='global_step', + verbosity=1, + self_tuning_ruleset=False, + strict=False, +): """Get times to target for each workload in a submission. Args: @@ -191,60 +199,72 @@ def get_workloads_time_to_target(submission, if num_workloads != NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS: if strict: raise ValueError( - f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads ' - f'but found {num_workloads} workloads for {submission_name}.') - logging.warning( f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads ' - f'but found {num_workloads} workloads for {submission_name}.') + f'but found {num_workloads} workloads for {submission_name}.' + ) + logging.warning( + f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads ' + f'but found {num_workloads} workloads for {submission_name}.' + ) # For each workload get submission time get the submission times to target. for workload, group in submission.groupby('workload'): - validation_metric, validation_target = scoring_utils.get_workload_metrics_and_targets(workload) + validation_metric, validation_target = ( + scoring_utils.get_workload_metrics_and_targets(workload) + ) # Check number of studies time_vals_per_study = [] num_studies = len(group.groupby('study')) if num_studies != NUM_STUDIES: if strict: - raise ValueError(f'Expecting {NUM_STUDIES} studies for workload ' - f'{workload} but found {num_studies} studies ' - f'for {submission_name}.') + raise ValueError( + f'Expecting {NUM_STUDIES} studies for workload ' + f'{workload} but found {num_studies} studies ' + f'for {submission_name}.' + ) else: - logging.warning(f'Expecting {NUM_STUDIES} studies for workload ' - f'{workload} but found {num_studies} studies ' - f'for {submission_name}.') + logging.warning( + f'Expecting {NUM_STUDIES} studies for workload ' + f'{workload} but found {num_studies} studies ' + f'for {submission_name}.' + ) # For each study check trials for study, group in group.groupby('study'): - # Check number of trials per study num_trials = len(group) if num_trials != NUM_TRIALS and not self_tuning_ruleset: if strict: raise ValueError( - f'In Study {study}: Expecting {NUM_TRIALS} trials for workload ' - f'{workload} but found {num_trials} trials ' - f'for {submission_name}.') + f'In Study {study}: Expecting {NUM_TRIALS} trials for workload ' + f'{workload} but found {num_trials} trials ' + f'for {submission_name}.' 
+ ) else: logging.warning( - f'In Study {study}: Expecting {NUM_TRIALS} trials for workload ' - f'{workload} but found {num_trials} trials ' - f'for {submission_name}.') + f'In Study {study}: Expecting {NUM_TRIALS} trials for workload ' + f'{workload} but found {num_trials} trials ' + f'for {submission_name}.' + ) # Get trial and time index that reaches target trial_idx, time_idx = get_best_trial_index( - group, validation_metric, validation_target) + group, validation_metric, validation_target + ) if time_idx > -1: time_val = group[time_col].loc[trial_idx][time_idx] else: time_val = float('inf') time_vals_per_study.append(time_val) - workloads.append({ + workloads.append( + { 'submission': submission_name, 'workload': re.sub(r'_(jax|pytorch)$', '', workload), time_col: np.median(time_vals_per_study), - }) + } + ) df = pd.DataFrame.from_records(workloads) df = df.pivot(index='submission', columns='workload', values=time_col) @@ -252,7 +272,6 @@ def get_workloads_time_to_target(submission, def variant_criteria_filter(base_workload, variant_workload): - def filter(x): try: if x[variant_workload] == np.inf: @@ -269,17 +288,19 @@ def filter(x): return filter -def compute_performance_profiles(submissions, - time_col='global_step', - min_tau=1.0, - max_tau=None, - reference_submission_tag=None, - num_points=100, - scale='linear', - verbosity=0, - strict=False, - self_tuning_ruleset=False, - output_dir=None): +def compute_performance_profiles( + submissions, + time_col='global_step', + min_tau=1.0, + max_tau=None, + reference_submission_tag=None, + num_points=100, + scale='linear', + verbosity=0, + strict=False, + self_tuning_ruleset=False, + output_dir=None, +): """Compute performance profiles for a set of submission by some time column. Args: @@ -308,16 +329,20 @@ def compute_performance_profiles(submissions, for submission_tag, submission in submissions.items(): logging.info( - f'\nComputing performance profile with respect to `{time_col}` for ' - f'{submission_tag}') + f'\nComputing performance profile with respect to `{time_col}` for ' + f'{submission_tag}' + ) # Get time to targets for each submission across studies and trials dfs.append( - get_workloads_time_to_target(submission, - submission_tag, - time_col, - verbosity, - self_tuning_ruleset, - strict)) + get_workloads_time_to_target( + submission, + submission_tag, + time_col, + verbosity, + self_tuning_ruleset, + strict, + ) + ) df = pd.concat(dfs) # Restrict to base and sampled held-out workloads # (ignore the additional workload variants of the baseline @@ -335,7 +360,8 @@ def compute_performance_profiles(submissions, # If base do not have finite score set variant score to inf base_workload = get_base_workload_name(workload) df[workload] = df.apply( - variant_criteria_filter(workload, base_workload), axis=1) + variant_criteria_filter(workload, base_workload), axis=1 + ) # Set score to inf if not within 4x of fastest submission best_scores = df.min(axis=0) @@ -347,17 +373,20 @@ def compute_performance_profiles(submissions, # If variants do not have finite score set base_workload score to inf base_workload = get_base_workload_name(workload) df[base_workload] = df.apply( - variant_criteria_filter(base_workload, workload), axis=1) + variant_criteria_filter(base_workload, workload), axis=1 + ) df = df[BASE_WORKLOADS] if verbosity > 0: logging.info('\n`{time_col}` to reach target:') - with pd.option_context('display.max_rows', - None, - 'display.max_columns', - None, - 'display.width', - 1000): + with pd.option_context( + 
'display.max_rows', + None, + 'display.max_columns', + None, + 'display.width', + 1000, + ): logging.info(df) # Divide by the fastest. @@ -368,12 +397,14 @@ def compute_performance_profiles(submissions, if verbosity > 0: logging.info('\n`{time_col}` to reach target normalized to best:') - with pd.option_context('display.max_rows', - None, - 'display.max_columns', - None, - 'display.width', - 1000): + with pd.option_context( + 'display.max_rows', + None, + 'display.max_columns', + None, + 'display.width', + 1000, + ): logging.info(df) # If no max_tau is supplied, choose the value of tau that would plot all non @@ -385,7 +416,8 @@ def compute_performance_profiles(submissions, points = np.linspace(min_tau, max_tau, num=num_points) elif scale == 'log': points = np.logspace( - np.log10(min_tau), np.log10(max_tau), num=num_points, base=10.0) + np.log10(min_tau), np.log10(max_tau), num=num_points, base=10.0 + ) def rho(r, tau): return (r <= tau).sum(axis=1) / NUM_BASE_WORKLOADS @@ -431,11 +463,9 @@ def maybe_save_df_to_csv(save_dir, df, path, **to_csv_kwargs): df.to_csv(fout, **to_csv_kwargs) -def plot_performance_profiles(perf_df, - df_col, - scale='linear', - save_dir=None, - figsize=(30, 10)): +def plot_performance_profiles( + perf_df, df_col, scale='linear', save_dir=None, figsize=(30, 10) +): """Plot performance profiles. Args: @@ -462,6 +492,6 @@ def plot_performance_profiles(perf_df, fig.legend(bbox_to_anchor=(1.0, 1.0)) plt.tight_layout() maybe_save_figure(save_dir, f'performance_profile_by_{df_col_display}') - maybe_save_df_to_csv(save_dir, - perf_df, - f'performance_profile_{df_col_display}.csv') + maybe_save_df_to_csv( + save_dir, perf_df, f'performance_profile_{df_col_display}.csv' + ) diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index f07dc8cdd..b48509f02 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -17,57 +17,62 @@ import os import pickle -from absl import app -from absl import flags -from absl import logging import numpy as np import pandas as pd import performance_profile import scoring_utils +from absl import app, flags, logging from tabulate import tabulate flags.DEFINE_string( - 'submission_directory', - None, - 'Path to submission directory containing experiment directories.') + 'submission_directory', + None, + 'Path to submission directory containing experiment directories.', +) flags.DEFINE_string( - 'output_dir', - 'scoring_results', - 'Path to save performance profile artifacts, submission_summaries and results files.' + 'output_dir', + 'scoring_results', + 'Path to save performance profile artifacts, submission_summaries and results files.', +) +flags.DEFINE_boolean( + 'compute_performance_profiles', + False, + 'Whether or not to compute the performance profiles.', ) -flags.DEFINE_boolean('compute_performance_profiles', - False, - 'Whether or not to compute the performance profiles.') flags.DEFINE_boolean( - 'strict', - False, - 'Whether to enforce scoring criteria on variant performance and on' - '5-trial median performance. Note that during official scoring this ' - 'flag will be set to True.') + 'strict', + False, + 'Whether to enforce scoring criteria on variant performance and on' + '5-trial median performance. 
Note that during official scoring this ' + 'flag will be set to True.', +) flags.DEFINE_boolean( - 'self_tuning_ruleset', - False, - 'Whether to score on self-tuning ruleset or externally tuned ruleset') + 'self_tuning_ruleset', + False, + 'Whether to score on self-tuning ruleset or externally tuned ruleset', +) flags.DEFINE_string( - 'save_results_to_filename', - None, - 'Filename to save the processed results that are fed into the performance profile functions.' + 'save_results_to_filename', + None, + 'Filename to save the processed results that are fed into the performance profile functions.', ) flags.DEFINE_string( - 'load_results_from_filename', - None, - 'Filename to load processed results from that are fed into performance profile functions' + 'load_results_from_filename', + None, + 'Filename to load processed results from that are fed into performance profile functions', ) flags.DEFINE_string( - 'exclude_submissions', - '', - 'Optional comma seperated list of names of submissions to exclude from scoring.' + 'exclude_submissions', + '', + 'Optional comma seperated list of names of submissions to exclude from scoring.', ) FLAGS = flags.FLAGS def get_summary_df(workload, workload_df, include_test_split=False): - validation_metric, validation_target = scoring_utils.get_workload_metrics_and_targets(workload, split='validation') + validation_metric, validation_target = ( + scoring_utils.get_workload_metrics_and_targets(workload, split='validation') + ) is_minimized = performance_profile.check_if_minimized(validation_metric) target_op = operator.le if is_minimized else operator.ge @@ -80,47 +85,69 @@ def get_summary_df(workload, workload_df, include_test_split=False): summary_df['val target metric name'] = validation_metric summary_df['val target metric value'] = validation_target - summary_df['val target reached'] = workload_df[validation_metric].apply( - lambda x: target_op(x, validation_target)).apply(np.any) + summary_df['val target reached'] = ( + workload_df[validation_metric] + .apply(lambda x: target_op(x, validation_target)) + .apply(np.any) + ) summary_df['best metric value on val'] = workload_df[validation_metric].apply( - lambda x: best_op(x)) + lambda x: best_op(x) + ) workload_df['index best eval on val'] = workload_df[validation_metric].apply( - lambda x: idx_op(x)) + lambda x: idx_op(x) + ) summary_df['time to best eval on val (s)'] = workload_df.apply( - lambda x: x['accumulated_submission_time'][x['index best eval on val']], - axis=1) - workload_df['val target reached'] = workload_df[validation_metric].apply( - lambda x: target_op(x, validation_target)).apply(np.any) + lambda x: x['accumulated_submission_time'][x['index best eval on val']], + axis=1, + ) + workload_df['val target reached'] = ( + workload_df[validation_metric] + .apply(lambda x: target_op(x, validation_target)) + .apply(np.any) + ) workload_df['index to target on val'] = workload_df.apply( - lambda x: np.argmax(target_op(x[validation_metric], validation_target)) - if x['val target reached'] else np.nan, - axis=1) + lambda x: np.argmax(target_op(x[validation_metric], validation_target)) + if x['val target reached'] + else np.nan, + axis=1, + ) summary_df['time to target on val (s)'] = workload_df.apply( - lambda x: x['accumulated_submission_time'][int(x[ - 'index to target on val'])] if x['val target reached'] else np.inf, - axis=1) + lambda x: x['accumulated_submission_time'][int(x['index to target on val'])] + if x['val target reached'] + else np.inf, + axis=1, + ) # test metrics if include_test_split: - 
test_metric, test_target = scoring_utils.get_workload_metrics_and_targets(workload, split='test') + test_metric, test_target = scoring_utils.get_workload_metrics_and_targets( + workload, split='test' + ) summary_df['test target metric name'] = test_metric summary_df['test target metric value'] = test_target - summary_df['test target reached'] = workload_df[test_metric].apply( - lambda x: target_op(x, test_target)).apply(np.any) + summary_df['test target reached'] = ( + workload_df[test_metric] + .apply(lambda x: target_op(x, test_target)) + .apply(np.any) + ) summary_df['best metric value on test'] = workload_df[test_metric].apply( - lambda x: best_op(x)) + lambda x: best_op(x) + ) workload_df['index best eval on test'] = workload_df[test_metric].apply( - lambda x: idx_op(x)) + lambda x: idx_op(x) + ) summary_df['time to best eval on test (s)'] = workload_df.apply( - lambda x: x['accumulated_submission_time'][x['index best eval on test'] - ], - axis=1) + lambda x: x['accumulated_submission_time'][x['index best eval on test']], + axis=1, + ) summary_df['time to target on test (s)'] = summary_df.apply( - lambda x: x['time to best eval on test (s)'] - if x['test target reached'] else np.inf, - axis=1) + lambda x: x['time to best eval on test (s)'] + if x['test target reached'] + else np.inf, + axis=1, + ) return summary_df @@ -134,7 +161,8 @@ def get_submission_summary(df, include_test_split=True): print(df) for workload, group in df.groupby('workload'): summary_df = get_summary_df( - workload, group, include_test_split=include_test_split) + workload, group, include_test_split=include_test_split + ) dfs.append(summary_df) df = pd.concat(dfs) @@ -161,13 +189,13 @@ def compute_leaderboard_score(df, normalize=True): def main(_): results = {} os.makedirs(FLAGS.output_dir, exist_ok=True) - logging.info(f"Scoring submissions in {FLAGS.submission_directory}") + logging.info(f'Scoring submissions in {FLAGS.submission_directory}') # Optionally read results to filename if FLAGS.load_results_from_filename: with open( - os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename), - 'rb') as f: + os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename), 'rb' + ) as f: results = pickle.load(f) else: for submission in os.listdir(FLAGS.submission_directory): @@ -179,44 +207,46 @@ def main(_): results[submission] = df summary_df = get_submission_summary(df) with open( - os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), - 'w') as fout: + os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), 'w' + ) as fout: summary_df.to_csv(fout) # Optionally save results to filename if FLAGS.save_results_to_filename: with open( - os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename), - 'wb') as f: + os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename), 'wb' + ) as f: pickle.dump(results, f) if not FLAGS.strict: logging.warning( - 'You are running with strict=False. This will relax ' - 'scoring criteria on the held-out workloads, number of trials and number ' - 'of studies. Your score may not be an accurate representation ' - 'under competition scoring rules. To enforce the criteria set strict=True.' + 'You are running with strict=False. This will relax ' + 'scoring criteria on the held-out workloads, number of trials and number ' + 'of studies. Your score may not be an accurate representation ' + 'under competition scoring rules. To enforce the criteria set strict=True.' 
) if FLAGS.compute_performance_profiles: performance_profile_df = performance_profile.compute_performance_profiles( - results, - time_col='score', - min_tau=1.0, - max_tau=4.0, - reference_submission_tag=None, - num_points=100, - scale='linear', - verbosity=0, - self_tuning_ruleset=FLAGS.self_tuning_ruleset, - strict=FLAGS.strict, - output_dir=FLAGS.output_dir, + results, + time_col='score', + min_tau=1.0, + max_tau=4.0, + reference_submission_tag=None, + num_points=100, + scale='linear', + verbosity=0, + self_tuning_ruleset=FLAGS.self_tuning_ruleset, + strict=FLAGS.strict, + output_dir=FLAGS.output_dir, ) if not os.path.exists(FLAGS.output_dir): os.mkdir(FLAGS.output_dir) performance_profile.plot_performance_profiles( - performance_profile_df, 'score', save_dir=FLAGS.output_dir) + performance_profile_df, 'score', save_dir=FLAGS.output_dir + ) performance_profile_str = tabulate( - performance_profile_df.T, headers='keys', tablefmt='psql') + performance_profile_df.T, headers='keys', tablefmt='psql' + ) logging.info(f'Performance profile:\n {performance_profile_str}') scores = compute_leaderboard_score(performance_profile_df) scores.to_csv(os.path.join(FLAGS.output_dir, 'scores.csv')) diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py index ac513816e..5be6c790c 100644 --- a/scoring/scoring_utils.py +++ b/scoring/scoring_utils.py @@ -4,8 +4,8 @@ import os import re -from absl import logging import pandas as pd +from absl import logging import algoperf.workloads.workloads as workloads_registry @@ -13,7 +13,7 @@ METRICS_LINE_REGEX = '(.*) Metrics: ({.*})' TRIAL_DIR_REGEX = 'trial_(\d+)' MEASUREMENTS_FILENAME = 'eval_measurements.csv' -TIMESTAMP = r"-\d{4}(-\d{2}){5}" +TIMESTAMP = r'-\d{4}(-\d{2}){5}' WORKLOADS = workloads_registry.WORKLOADS WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)' @@ -22,12 +22,11 @@ #### File IO helper functions ### def get_logfile_paths(logdir): - """Gets all files ending in .log in logdir - """ + """Gets all files ending in .log in logdir""" filenames = os.listdir(logdir) logfile_paths = [] for f in filenames: - if f.endswith(".log"): + if f.endswith('.log'): f = os.path.join(logdir, f) logfile_paths.append(f) return logfile_paths @@ -36,23 +35,23 @@ def get_logfile_paths(logdir): ### Logfile reading helper functions ### def decode_metrics_line(line): """Convert metrics line to dict. - Args: - line: str - - Returns: - dict_of_lists: dict where keys are metric names and vals - are lists of values. - e.g. {'loss':[5.1, 3.2, 1.0], - 'step':[100, 200, 300]} - """ + Args: + line: str + + Returns: + dict_of_lists: dict where keys are metric names and vals + are lists of values. + e.g. 
{'loss':[5.1, 3.2, 1.0], + 'step':[100, 200, 300]} + """ eval_results = [] dict_str = re.match(METRICS_LINE_REGEX, line).group(2) - dict_str = dict_str.replace("'", "\"") - dict_str = dict_str.replace("(", "") - dict_str = dict_str.replace(")", "") - dict_str = dict_str.replace("DeviceArray", "") - dict_str = dict_str.replace(", dtype=float32", "") - dict_str = dict_str.replace("nan", "0") + dict_str = dict_str.replace("'", '"') + dict_str = dict_str.replace('(', '') + dict_str = dict_str.replace(')', '') + dict_str = dict_str.replace('DeviceArray', '') + dict_str = dict_str.replace(', dtype=float32', '') + dict_str = dict_str.replace('nan', '0') metrics_dict = json.loads(dict_str) for item in metrics_dict['eval_results']: if isinstance(item, dict): @@ -73,18 +72,18 @@ def decode_metrics_line(line): def get_trials_dict(logfile): - """Get a dict of dicts with metrics for each - tuning run. - - Returns: - trials_dict: Dict of dicts where outer dict keys - are trial indices and inner dict key-value pairs - are metrics and list of values. - e.g. {'trial_0': {'loss':[5.1, 3.2, 1.0], - 'step':[100, 200, 300]}, - 'trial_1': {'loss':[5.1, 3.2, 1.0], - 'step':[100, 200, 300]}} - """ + """Get a dict of dicts with metrics for each + tuning run. + + Returns: + trials_dict: Dict of dicts where outer dict keys + are trial indices and inner dict key-value pairs + are metrics and list of values. + e.g. {'trial_0': {'loss':[5.1, 3.2, 1.0], + 'step':[100, 200, 300]}, + 'trial_1': {'loss':[5.1, 3.2, 1.0], + 'step':[100, 200, 300]}} + """ trial = 0 metrics_lines = {} with open(logfile, 'r') as f: @@ -100,16 +99,16 @@ def get_trials_dict(logfile): ### Results formatting helper functions ### def get_trials_df_dict(logfile): - """Get a dict with dataframes with metrics for each - tuning run. - Preferable format for saving dataframes for tables. - Args: - logfile: str path to logfile. - - Returns: - DataFrame where indices are index of eval and - columns are metric names. - """ + """Get a dict with dataframes with metrics for each + tuning run. + Preferable format for saving dataframes for tables. + Args: + logfile: str path to logfile. + + Returns: + DataFrame where indices are index of eval and + columns are metric names. + """ trials_dict = get_trials_dict(logfile) trials_df_dict = {} for trial, metrics in trials_dict.items(): @@ -119,20 +118,20 @@ def get_trials_df_dict(logfile): def get_trials_df(logfile): """Gets a df of per trial results from a logfile. - Args: - experiment_dir: str - - Returns: - df: DataFrame where indices are trials, columns are - metric names and values are lists. - e.g - +---------+-----------------+-----------------+ - | | loss | step | - |---------+-----------------+-----------------| - | trial_0 | [5.1, 3.2, 1.0] | [100, 200, 300] | - | trial_1 | [5.1, 3.2, 1.0] | [100, 200, 300] | - +---------+-----------------+-----------------+ - """ + Args: + experiment_dir: str + + Returns: + df: DataFrame where indices are trials, columns are + metric names and values are lists. 
+ e.g + +---------+-----------------+-----------------+ + | | loss | step | + |---------+-----------------+-----------------| + | trial_0 | [5.1, 3.2, 1.0] | [100, 200, 300] | + | trial_1 | [5.1, 3.2, 1.0] | [100, 200, 300] | + +---------+-----------------+-----------------+ + """ trials_dict = get_trials_dict(logfile) df = pd.DataFrame(trials_dict).transpose() return df @@ -141,13 +140,13 @@ def get_trials_df(logfile): ## Get scoring code def get_experiment_df(experiment_dir): """Gets a df of per trial results from an experiment dir. - The output df can be provided as input to - performance_profile.compute_performance_profiles. + The output df can be provided as input to + performance_profile.compute_performance_profiles. Args: - experiment_dir: path to experiment directory containing - results for workloads. Measurements from experiments - sharing the same prefix but different timestamps are - collected together. + experiment_dir: path to experiment directory containing + results for workloads. Measurements from experiments + sharing the same prefix but different timestamps are + collected together. The directory structure is assumed to be: + experiment_dir + study @@ -156,9 +155,9 @@ def get_experiment_df(experiment_dir): - eval_measurements.csv Returns: - df: DataFrame where indices are trials, columns are + df: DataFrame where indices are trials, columns are metric names and values are lists of length num evals. - e.g + e.g +----+-----------+--------+----------------------------+--------------------+--------------------+ | | workload | study |trial | validation/accuracy| score | |----+-----------+--------+----------------------------+--------------------+--------------------| @@ -167,39 +166,41 @@ def get_experiment_df(experiment_dir): """ df = pd.DataFrame() paths = filter( - lambda x: re.match(experiment_dir + TIMESTAMP, x) or x == experiment_dir, - glob.glob(f"{experiment_dir}*")) + lambda x: re.match(experiment_dir + TIMESTAMP, x) or x == experiment_dir, + glob.glob(f'{experiment_dir}*'), + ) for experiment_dir in paths: study_dirs = os.listdir(experiment_dir) for study_dir in study_dirs: workload_dirs = os.listdir(os.path.join(experiment_dir, study_dir)) workload_dirs = [ - w for w in workload_dirs - if os.path.isdir(os.path.join(experiment_dir, study_dir, w)) + w + for w in workload_dirs + if os.path.isdir(os.path.join(experiment_dir, study_dir, w)) ] print(workload_dirs) for workload in workload_dirs: data = { - 'workload': workload, + 'workload': workload, } logging.info(os.path.join(experiment_dir, study_dir, workload)) trial_dirs = [ - t for t in os.listdir( - os.path.join(experiment_dir, study_dir, workload)) - if re.match(TRIAL_DIR_REGEX, t) + t + for t in os.listdir(os.path.join(experiment_dir, study_dir, workload)) + if re.match(TRIAL_DIR_REGEX, t) ] for trial in trial_dirs: eval_measurements_filepath = os.path.join( - experiment_dir, - study_dir, - workload, - trial, - MEASUREMENTS_FILENAME, + experiment_dir, + study_dir, + workload, + trial, + MEASUREMENTS_FILENAME, ) try: trial_df = pd.read_csv(eval_measurements_filepath) - except FileNotFoundError as e: + except FileNotFoundError: logging.info(f'Could not read {eval_measurements_filepath}') continue data['trial'] = (trial, experiment_dir) @@ -221,14 +222,16 @@ def get_workload_metrics_and_targets(workload, split='validation'): # Extend path according to framework. 
workload_metadata['workload_path'] = os.path.join( - BASE_WORKLOADS_DIR, - workload_metadata['workload_path'] + f'{framework}', - 'workload.py') + BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + f'{framework}', + 'workload.py', + ) workload_init_kwargs = {} workload_obj = workloads_registry.import_workload( - workload_path=workload_metadata['workload_path'], - workload_class_name=workload_metadata['workload_class_name'], - workload_init_kwargs=workload_init_kwargs) + workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + workload_init_kwargs=workload_init_kwargs, + ) metric_name = workload_obj.target_metric_name if split == 'validation': metric = f'validation/{metric_name}' diff --git a/scoring/test_performance_profile.py b/scoring/test_performance_profile.py deleted file mode 100644 index 166c82d09..000000000 --- a/scoring/test_performance_profile.py +++ /dev/null @@ -1,25 +0,0 @@ -import os - -from absl.testing import absltest - -from scoring import performance_profile -from scoring import scoring_utils - - -class Test(absltest.TestCase): - - def test_get_workloads_time_to_target(self): - # TODO(kasimbeg) - pass - - def test_get_best_trial_index(self): - # TODO(kasimbeg) - pass - - def test_compute_performance_profiles(self): - # TODO(kasimbeg) - pass - - -if __name__ == '__main__': - absltest.main() diff --git a/scoring/test_scoring_utils.py b/scoring/test_scoring_utils.py index 7509e3e46..64e141976 100644 --- a/scoring/test_scoring_utils.py +++ b/scoring/test_scoring_utils.py @@ -1,8 +1,5 @@ -import os - from absl.testing import absltest -from scoring import performance_profile from scoring import scoring_utils TEST_LOGFILE = 'test_data/adamw_fastmri_jax_04-18-2023-13-10-58.log' @@ -11,7 +8,6 @@ class Test(absltest.TestCase): - def test_get_trials_dict(self): trials_dict = scoring_utils.get_trials_dict(TEST_LOGFILE) self.assertEqual(len(trials_dict['1']['global_step']), NUM_EVALS) diff --git a/scoring/utils/package_logs.py b/scoring/utils/package_logs.py index 074075abf..e341570a1 100644 --- a/scoring/utils/package_logs.py +++ b/scoring/utils/package_logs.py @@ -3,11 +3,11 @@ python3 package_logs.py --experiment_dir --destination_dir """ + import os import shutil -from absl import app -from absl import flags +from absl import app, flags flags.DEFINE_string('experiment_dir', None, 'Path to experiment.') flags.DEFINE_string('destination_dir', None, 'Path to save submission logs') @@ -17,10 +17,10 @@ def move_logs(experiment_dir, destination_dir): """Copy files from experiment path to destination directory. - Args: - experiment_dir: Path to experiment dir. - destination_dir: Path to destination dir. - """ + Args: + experiment_dir: Path to experiment dir. + destination_dir: Path to destination dir. 
+ """ if not os.path.exists(experiment_dir): raise IOError(f'Directory does not exist {destination_dir}') diff --git a/scoring/utils/run_workloads.py b/scoring/utils/run_workloads.py index 683fb3c63..e2de01130 100644 --- a/scoring/utils/run_workloads.py +++ b/scoring/utils/run_workloads.py @@ -16,101 +16,111 @@ import subprocess import time -from absl import app -from absl import flags -from absl import logging +from absl import app, flags, logging +import docker from algoperf import random_utils as prng from algoperf.workloads.workloads import get_base_workload_name -import docker flags.DEFINE_string( - 'docker_image_url', - 'europe-west4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo/algoperf_jax_dev', - 'URL to docker image') + 'docker_image_url', + 'europe-west4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo/algoperf_jax_dev', + 'URL to docker image', +) flags.DEFINE_integer( - 'run_percentage', - 100, - 'Percentage of max num steps to run for.' - 'Must set the flag enable_step_budget to True for this to take effect.') -flags.DEFINE_string('experiment_name', - 'my_experiment', - 'Name of top sub directory in experiment dir.') -flags.DEFINE_boolean('rsync_data', - True, - 'Whether or not to transfer the data from GCP w rsync.') + 'run_percentage', + 100, + 'Percentage of max num steps to run for.' + 'Must set the flag enable_step_budget to True for this to take effect.', +) +flags.DEFINE_string( + 'experiment_name', + 'my_experiment', + 'Name of top sub directory in experiment dir.', +) +flags.DEFINE_boolean( + 'rsync_data', True, 'Whether or not to transfer the data from GCP w rsync.' +) flags.DEFINE_boolean('local', False, 'Mount local algorithmic-efficiency repo.') flags.DEFINE_string( - 'submission_path', - 'prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py', - 'Path to reference submission.') + 'submission_path', + 'prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py', + 'Path to reference submission.', +) flags.DEFINE_string( - 'tuning_search_space', - 'prize_qualification_baselines/external_tuning/tuning_search_space.json', - 'Path to tuning search space.') + 'tuning_search_space', + 'prize_qualification_baselines/external_tuning/tuning_search_space.json', + 'Path to tuning search space.', +) flags.DEFINE_string('framework', 'jax', 'Can be either PyTorch or JAX.') flags.DEFINE_boolean( - 'dry_run', - False, - 'Whether or not to actually run the docker containers. ' - 'If False, simply print the docker run commands. ') + 'dry_run', + False, + 'Whether or not to actually run the docker containers. ' + 'If False, simply print the docker run commands. ', +) flags.DEFINE_enum( - 'tuning_ruleset', - 'external', - enum_values=['external', 'self'], - help='Can be either external of self.') + 'tuning_ruleset', + 'external', + enum_values=['external', 'self'], + help='Can be either external of self.', +) flags.DEFINE_integer('num_studies', 5, 'Number of studies to run') flags.DEFINE_integer('study_start_index', None, 'Start index for studies.') flags.DEFINE_integer('study_end_index', None, 'End index for studies.') flags.DEFINE_integer('num_tuning_trials', 5, 'Number of tuning trials.') -flags.DEFINE_integer('hparam_start_index', - None, - 'Start index for tuning trials.') +flags.DEFINE_integer( + 'hparam_start_index', None, 'Start index for tuning trials.' 
+) flags.DEFINE_integer('hparam_end_index', None, 'End index for tuning trials.') flags.DEFINE_integer('seed', None, 'Random seed for evaluating a submission.') -flags.DEFINE_integer('submission_id', - 0, - 'Submission ID to generate study and hparam seeds.') -flags.DEFINE_string('held_out_workloads_config_path', - None, - 'Path to config containing held-out workloads') +flags.DEFINE_integer( + 'submission_id', 0, 'Submission ID to generate study and hparam seeds.' +) flags.DEFINE_string( - 'workload_metadata_path', - None, - 'Path to config containing dataset and maximum number of steps per workload.' - 'The default values of these are set to the full budgets as determined ' - 'via the target-setting procedure. ' - 'We provide workload_metadata_external_tuning.json and ' - 'workload_metadata_self_tuning.json as references.' - 'Note that training will be interrupted at either the set maximum number ' - 'of steps or the fixed workload maximum run time, whichever comes first. ' - 'If your algorithm has a smaller per step time than our baselines ' - 'you may want to increase the number of steps per workload.') + 'held_out_workloads_config_path', + None, + 'Path to config containing held-out workloads', +) flags.DEFINE_string( - 'workloads', - None, - 'String representing a comma separated list of workload names.' - 'If not None, only run this workload, else run all workloads in workload_metadata_path.' + 'workload_metadata_path', + None, + 'Path to config containing dataset and maximum number of steps per workload.' + 'The default values of these are set to the full budgets as determined ' + 'via the target-setting procedure. ' + 'We provide workload_metadata_external_tuning.json and ' + 'workload_metadata_self_tuning.json as references.' + 'Note that training will be interrupted at either the set maximum number ' + 'of steps or the fixed workload maximum run time, whichever comes first. ' + 'If your algorithm has a smaller per step time than our baselines ' + 'you may want to increase the number of steps per workload.', +) +flags.DEFINE_string( + 'workloads', + None, + 'String representing a comma separated list of workload names.' + 'If not None, only run this workload, else run all workloads in workload_metadata_path.', +) +flags.DEFINE_string( + 'additional_requirements_path', None, 'Path to requirements.txt if any.' ) -flags.DEFINE_string('additional_requirements_path', - None, - 'Path to requirements.txt if any.') flags.DEFINE_integer( - 'max_steps', - None, - 'Maximum number of steps to run. Must set flag enable_step_budget.' - 'This flag takes precedence over the run_percentage flag.') + 'max_steps', + None, + 'Maximum number of steps to run. Must set flag enable_step_budget.' + 'This flag takes precedence over the run_percentage flag.', +) flags.DEFINE_bool( - 'enable_step_budget', - False, - 'Flag that has to be explicitly set to override time budgets to step budget percentage.' 
+ 'enable_step_budget', + False, + 'Flag that has to be explicitly set to override time budgets to step budget percentage.', ) FLAGS = flags.FLAGS def read_held_out_workloads(filename): - with open(filename, "r") as f: + with open(filename, 'r') as f: held_out_workloads = json.load(f) return held_out_workloads @@ -132,11 +142,13 @@ def kill_containers(): def gpu_is_active(): - output = subprocess.check_output([ + output = subprocess.check_output( + [ 'nvidia-smi', '--query-gpu=utilization.gpu', - '--format=csv,noheader,nounits' - ]) + '--format=csv,noheader,nounits', + ] + ) return any(int(x) > 0 for x in output.decode().splitlines()) @@ -151,7 +163,8 @@ def wait_until_container_not_running(sleep_interval=5 * 60): gpu_last_active = datetime.datetime.now().timestamp() if (datetime.datetime.now().timestamp() - gpu_last_active) > 45 * 60: kill_containers( - "Killing container: GPUs have been inactive > 45 minutes...") + 'Killing container: GPUs have been inactive > 45 minutes...' + ) time.sleep(sleep_interval) return @@ -167,7 +180,9 @@ def main(_): hparam_start_index_flag = '' hparam_end_index_flag = '' if FLAGS.hparam_start_index: - hparam_start_index_flag = f'--hparam_start_index {FLAGS.hparam_start_index} ' + hparam_start_index_flag = ( + f'--hparam_start_index {FLAGS.hparam_start_index} ' + ) if FLAGS.hparam_end_index: hparam_end_index_flag = f'--hparam_end_index {FLAGS.hparam_end_index} ' study_start_index = FLAGS.study_start_index if FLAGS.study_start_index else 0 @@ -178,7 +193,9 @@ def main(_): additional_requirements_path_flag = '' if FLAGS.additional_requirements_path: - additional_requirements_path_flag = f'--additional_requirements_path {FLAGS.additional_requirements_path} ' + additional_requirements_path_flag = ( + f'--additional_requirements_path {FLAGS.additional_requirements_path} ' + ) submission_id = FLAGS.submission_id @@ -188,7 +205,7 @@ def main(_): rng_seed = struct.unpack('I', os.urandom(4))[0] logging.info('Using RNG seed %d', rng_seed) - rng_key = (prng.fold_in(prng.PRNGKey(rng_seed), hash(submission_id))) + rng_key = prng.fold_in(prng.PRNGKey(rng_seed), hash(submission_id)) with open(FLAGS.workload_metadata_path) as f: workload_metadata = json.load(f) @@ -199,20 +216,24 @@ def main(_): # Read heldout workloads if FLAGS.held_out_workloads_config_path: held_out_workloads = read_held_out_workloads( - FLAGS.held_out_workloads_config_path) + FLAGS.held_out_workloads_config_path + ) workloads = workloads + held_out_workloads # Filter workloads if explicit workloads specified if FLAGS.workloads is not None: workloads = list( - filter(lambda x: x in FLAGS.workloads.split(','), workloads)) + filter(lambda x: x in FLAGS.workloads.split(','), workloads) + ) if len(workloads) != len(FLAGS.workloads.split(',')): unmatched_workloads = set(FLAGS.workloads.split(',')) - set(workloads) raise ValueError(f'Invalid workload name {unmatched_workloads}') rng_subkeys = prng.split(rng_key, num_studies) - for study_index, rng_subkey in zip(range(study_start_index, study_end_index + 1), rng_subkeys): + for study_index, rng_subkey in zip( + range(study_start_index, study_end_index + 1), rng_subkeys + ): print('-' * 100) print('*' * 40, f'Starting study {study_index + 1}/{num_studies}', '*' * 40) print('-' * 100) @@ -225,40 +246,46 @@ def main(_): base_workload_name = get_base_workload_name(workload) wait_until_container_not_running() os.system( - "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches + "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'" + ) # clear caches print('=' * 
100) dataset = workload_metadata[base_workload_name]['dataset'] max_steps_flag = '' if FLAGS.enable_step_budget: - run_fraction = FLAGS.run_percentage / 100. + run_fraction = FLAGS.run_percentage / 100.0 if FLAGS.max_steps is None: - max_steps = int(workload_metadata[base_workload_name]['max_steps'] * - run_fraction) + max_steps = int( + workload_metadata[base_workload_name]['max_steps'] * run_fraction + ) else: max_steps = FLAGS.max_steps max_steps_flag = f'-m {max_steps}' mount_repo_flag = '' if FLAGS.local: - mount_repo_flag = '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency ' - command = ('docker run -t -d -v /home/kasimbeg/data/:/data/ ' - '-v /home/kasimbeg/experiment_runs/:/experiment_runs ' - '-v /home/kasimbeg/experiment_runs/logs:/logs ' - f'{mount_repo_flag}' - '--gpus all --ipc=host ' - f'{docker_image_url} ' - f'-d {dataset} ' - f'-f {framework} ' - f'-s {submission_path} ' - f'-w {workload} ' - f'-e {study_dir} ' - f'{max_steps_flag} ' - f'--num_tuning_trials {num_tuning_trials} ' - f'--rng_seed {run_seed} ' - f'{additional_requirements_path_flag}' - '-c false ' - '-o true ' - '-i true ') + mount_repo_flag = ( + '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency ' + ) + command = ( + 'docker run -t -d -v /home/kasimbeg/data/:/data/ ' + '-v /home/kasimbeg/experiment_runs/:/experiment_runs ' + '-v /home/kasimbeg/experiment_runs/logs:/logs ' + f'{mount_repo_flag}' + '--gpus all --ipc=host ' + f'{docker_image_url} ' + f'-d {dataset} ' + f'-f {framework} ' + f'-s {submission_path} ' + f'-w {workload} ' + f'-e {study_dir} ' + f'{max_steps_flag} ' + f'--num_tuning_trials {num_tuning_trials} ' + f'--rng_seed {run_seed} ' + f'{additional_requirements_path_flag}' + '-c false ' + '-o true ' + '-i true ' + ) # Append tuning ruleset flags tuning_ruleset_flags = '' @@ -280,18 +307,19 @@ def main(_): return_code = 0 if return_code == 0: print( - f'SUCCESS: container for {framework} {workload} launched successfully' + f'SUCCESS: container for {framework} {workload} launched successfully' ) print(f'Command: {command}') print(f'Results will be logged to {experiment_name}') else: print( - f'Failed: container for {framework} {workload} failed with exit code {return_code}.' + f'Failed: container for {framework} {workload} failed with exit code {return_code}.' ) print(f'Command: {command}') wait_until_container_not_running() os.system( - "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches + "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'" + ) # clear caches print('=' * 100) diff --git a/scoring/utils/slurm/README.md b/scoring/utils/slurm/README.md index ffd56fbf3..a8e41f04b 100644 --- a/scoring/utils/slurm/README.md +++ b/scoring/utils/slurm/README.md @@ -1,26 +1,63 @@ +# Launching SLURM jobs with SBATCH + This folder contains a SLURM batch script that can be used to run jobs where each job corresponds to a training run on a given workload, training algorithm, random seed and tuning trial (if on external tuning ruleset). To launch jobs: -1) Generate a job config. The following command will generate a config.json. -``` + +1. Generate a job config. The following command will generate a config.json. + +```bash python3 make_job_config.py \ --submission_path \ --tuning_search_space \ --experiment_dir $HOME/experiments/ \ --framework ``` -2) Save the config.json in the same directory you will run the sbatch script from. -3) Check the sbatch script `run_jobs.sh`. + +2. Save the config.json in the same directory you will run the sbatch script from. +3. 
Copy the example sbatch script `run_jobs.sh`. + - Set the task range to the number of tasks in the config. + ``` #SBATCH --array=0-119 ``` + - Set the output and error logs directory for the SLURM logs. + ``` #SBATCH --output=experiments///job_%A_%a.out #SBATCH --error=experiments///job_%A_%a.err ``` -4) Submit a SLURM batch job by running: + +- Update the gcp project information, docker image, config file path and bucket to save the logs to as necessary: + +``` +REPO="us-central1-docker.pkg.dev" +IMAGE="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_main" +y | gcloud auth configure-docker $REPO +docker pull $IMAGE +# Job config (ATTENTION: you may want to modify this) +config_file="$HOME/configs/pmap_job_config.json" # Replace with your config file path +LOGS_BUCKET="algoperf-runs-internal" +``` + +4. Submit a SLURM batch job by running: + ``` sbatch run_jobs.sh ``` + +# Set up new SLURM cluster + +If you are setting up a new cluster, we recommend using the [HPC toolkit to set up a SLURM cluster](https://cloud.google.com/cluster-toolkit/docs/quickstarts/slurm-cluster). +To set up the new cluster: + +1. [Install the Google Cluster Toolkit](https://github.com/GoogleCloudPlatform/cluster-toolkit?tab=readme-ov-file#quickstart). +2. Create and deploy a packer node to create a base image for the cluster nodes. See [packer builder terraform blueprint](/scoring/utils/slurm/algoperf_slurm_packer_builder.yaml). +3. Manually update the image: + 1. Create a VM from the Disk image created in the previous step. + 2. Install the NVIDIA container toolkit on the VM. + 3. Transfer the data from GCP bucket to `/opt/data`. + 4. Create a new disk image from the VM. +4. Create and deploy the cluster. See [cluster terraform blueprint](/scoring/utils/slurm/algoperf_slurm_cluster.yaml). diff --git a/scoring/utils/slurm/algoperf_slurm_cluster.yaml b/scoring/utils/slurm/algoperf_slurm_cluster.yaml new file mode 100644 index 000000000..073fe98cc --- /dev/null +++ b/scoring/utils/slurm/algoperf_slurm_cluster.yaml @@ -0,0 +1,105 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +blueprint_name: algoperf-slurm-internal + +vars: + project_id: training-algorithms-external + deployment_name: algoperf-slurm-internal + region: europe-west4 + zone: europe-west4-a + disk_size_gb: 3000 + slurm_cluster_name: algoperf + image_name: algoperf-image-data-container-tkt + +# Recommended to use GCS backend for Terraform state +# See https://github.com/GoogleCloudPlatform/hpc-toolkit/tree/main/examples#optional-setting-up-a-remote-terraform-state +# +# terraform_backend_defaults: +# type: gcs +# configuration: +# bucket: <> + +deployment_groups: + - group: primary + modules: + - id: network + source: modules/network/vpc + + - id: homefs + source: community/modules/file-system/nfs-server + use: [network] + settings: + local_mounts: [/home] + disk_size: 3000 + zone: $(vars.zone) + + - id: script + source: modules/scripts/startup-script + settings: + + - group: cluster + modules: + - id: v100_nodeset + source: community/modules/compute/schedmd-slurm-gcp-v6-nodeset + use: + - network + settings: + node_count_dynamic_max: 25 # set to 0 if you want node to live forever + region: $(vars.region) + zone: $(vars.zone) + enable_placement: false + bandwidth_tier: gvnic_enabled + machine_type: n1-standard-64 + guest_accelerator: + - type: nvidia-tesla-v100 + count: 8 + instance_image: + project: $(vars.project_id) + name: $(vars.image_name) + instance_image_custom: true + + - id: v100_partition + source: community/modules/compute/schedmd-slurm-gcp-v6-partition + use: [v100_nodeset] + settings: + exclusive: false + partition_name: v100 + is_default: true + + - id: slurm_login + source: community/modules/scheduler/schedmd-slurm-gcp-v6-login + use: [network] + settings: + enable_login_public_ips: true + instance_image: + project: $(vars.project_id) + name: $(vars.image_name) + instance_image_custom: true + zone: $(vars.zone) + + - id: slurm_controller + source: community/modules/scheduler/schedmd-slurm-gcp-v6-controller + use: + - network + - v100_partition + - homefs + - slurm_login + settings: + enable_controller_public_ips: true + instance_image: + project: $(vars.project_id) + name: $(vars.image_name) + instance_image_custom: true + region: $(vars.region) diff --git a/scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml b/scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml new file mode 100644 index 000000000..f3b5be5dd --- /dev/null +++ b/scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml @@ -0,0 +1,131 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +--- +blueprint_name: algoperf-slurm-packer + +vars: + project_id: training-algorithms-external + deployment_name: algoperf-slurm-packer + region: europe-west4 + zone: europe-west4-a + new_image: + family: algoperf-image + project: $(vars.project_id) + disk_size_gb: 3000 + slurm_cluster_name: algoperf-packer + +# Recommended to use GCS backend for Terraform state +# See https://github.com/GoogleCloudPlatform/hpc-toolkit/tree/main/examples#optional-setting-up-a-remote-terraform-state +# +# terraform_backend_defaults: +# type: gcs +# configuration: +# bucket: <> + +deployment_groups: + - group: primary + modules: + - id: network + source: modules/network/vpc + + - id: script + source: modules/scripts/startup-script + settings: + region: $(vars.region) + install_ansible: true + docker: + enabled: true + world_writable: true + # (TODO) Do I need this? + configure_ssh_host_patterns: + - 10.0.0.* + - 10.1.0.* + - 10.2.0.* + - 10.3.0.* + - 10.4.0.* + - 10.5.0.* + - 10.6.0.* + - 10.7.0.* + - $(vars.slurm_cluster_name)* + runners: + - type: shell + destination: install-ml-libraries.sh + content: | + #!/bin/bash + # this script is designed to execute on Slurm images published by SchedMD that: + # - are based on Debian distribution of Linux + # - have NVIDIA drivers pre-installed + + set -e -o pipefail + + echo "deb https://packages.cloud.google.com/apt google-fast-socket main" > /etc/apt/sources.list.d/google-fast-socket.list + apt-get update --allow-releaseinfo-change + apt-get install --assume-yes google-fast-socket + + CONDA_BASE=/opt/conda + + if [ -d $CONDA_BASE ]; then + exit 0 + fi + + DL_DIR=\$(mktemp -d) + cd $DL_DIR + curl -L -O https://github.com/conda-forge/miniforge/releases/download/24.7.1-2/Miniforge3-24.7.1-2-Linux-x86_64.sh + HOME=$DL_DIR bash Miniforge3-24.7.1-2-Linux-x86_64.sh -b -p $CONDA_BASE + cd - + rm -rf $DL_DIR + unset DL_DIR + + source $CONDA_BASE/bin/activate base + conda init --system + conda config --system --set auto_activate_base False + # following channel ordering is important! use strict_priority! 
+ conda config --system --set channel_priority strict + conda update -n base conda --yes + + ### create a virtual environment for tensorflow + conda create -n tf python=3.11 --yes + conda activate tf + pip install tensorflow[and-cuda]==2.18.* + pip install tensorrt==10.6.* + + ### create a virtual environment for pytorch + conda create -n pytorch python=3.11 --yes + conda activate pytorch + pip install torch torchvision torchaudio + + - group: packer + modules: + - id: custom-image + source: modules/packer/custom-image + kind: packer + use: + - network + - script + settings: + # give VM a public IP to ensure startup script can reach public internet + # w/o new VPC + omit_external_ip: false + source_image_project_id: [schedmd-slurm-public] + # see latest in https://github.com/GoogleCloudPlatform/slurm-gcp/blob/master/docs/images.md#published-image-family + source_image_family: slurm-gcp-6-8-debian-11 + # You can find size of source image by using following command + # gcloud compute images describe-from-family --project schedmd-slurm-public + disk_size: $(vars.disk_size_gb) + image_family: $(vars.new_image.family) + # building this image does not require a GPU-enabled VM + machine_type: c2-standard-16 + state_timeout: 300m + zone: $(vars.zone) diff --git a/scoring/utils/slurm/config.json b/scoring/utils/slurm/config.json new file mode 100644 index 000000000..cb49f9bf4 --- /dev/null +++ b/scoring/utils/slurm/config.json @@ -0,0 +1,106 @@ +{ + "0": { + "framework": "jax", + "workload": "imagenet_resnet", + "dataset": "imagenet", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": 411096763, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "1": { + "framework": "jax", + "workload": "imagenet_vit", + "dataset": "imagenet", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": -1884713130, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "2": { + "framework": "jax", + "workload": "fastmri", + "dataset": "fastmri", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": -214785144, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "3": { + "framework": "jax", + "workload": "ogbg", + "dataset": "ogbg", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": -893097833, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "4": { + "framework": "jax", + "workload": "wmt", + "dataset": "wmt", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": 
-1244182279, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "5": { + "framework": "jax", + "workload": "librispeech_deepspeech", + "dataset": "librispeech", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": 1546003634, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "6": { + "framework": "jax", + "workload": "criteo1tb", + "dataset": "criteo1tb", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": -2062333143, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + }, + "7": { + "framework": "jax", + "workload": "librispeech_conformer", + "dataset": "librispeech", + "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py", + "experiment_dir": "$HOME/experiments/jit_switch/study_0", + "rng_seed": 409209730, + "tuning_ruleset": "external", + "num_tuning_trials": 1, + "hparam_start_index": 0, + "hparam_end_index": 1, + "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json" + } +} diff --git a/scoring/utils/slurm/make_job_config.py b/scoring/utils/slurm/make_job_config.py index 116e70459..f6a1ca158 100644 --- a/scoring/utils/slurm/make_job_config.py +++ b/scoring/utils/slurm/make_job_config.py @@ -6,60 +6,66 @@ --experiment_dir $HOME/experiments/ \ --framework """ + import json import os -from absl import app -from absl import flags import jax +from absl import app, flags SUBMISSION_PATH = 'submissions_algorithms/submissions/self_tuning/schedule_free_adamw_v2/submission.py' -EXPERIMENT_DIR = 'submissions/rolling_leaderboard/self_tuning/schedule_free_adamw_v2' +EXPERIMENT_DIR = ( + 'submissions/rolling_leaderboard/self_tuning/schedule_free_adamw_v2' +) TUNING_SEARCH_SPACE = None FRAMEWORK = 'pytorch' TUNING_RULESET = 'self' flags.DEFINE_string( - 'submission_path', - SUBMISSION_PATH, - 'Path to submission module relative to algorithmic-efficiency dir.') + 'submission_path', + SUBMISSION_PATH, + 'Path to submission module relative to algorithmic-efficiency dir.', +) +flags.DEFINE_string( + 'tuning_search_space', + TUNING_SEARCH_SPACE, + 'Path to tuning search space for submission module relative to algorithmic-efficiency dir.', +) flags.DEFINE_string( - 'tuning_search_space', - TUNING_SEARCH_SPACE, - 'Path to tuning search space for submission module relative to algorithmic-efficiency dir.' 
+ 'experiment_dir', + EXPERIMENT_DIR, + 'Path to experiment dir where logs will be saved.', ) -flags.DEFINE_string('experiment_dir', - EXPERIMENT_DIR, - 'Path to experiment dir where logs will be saved.') flags.DEFINE_enum( - 'framework', - FRAMEWORK, - enum_values=['jax', 'pytorch'], - help='Can be either pytorch or jax.') + 'framework', + FRAMEWORK, + enum_values=['jax', 'pytorch'], + help='Can be either pytorch or jax.', +) flags.DEFINE_integer('seed', 0, 'RNG seed to to generate study seeds from.') flags.DEFINE_enum( - 'tuning_ruleset', - TUNING_RULESET, - enum_values=['external', 'self'], - help='Which tuning ruleset to score this submission on. Can be external or self.' + 'tuning_ruleset', + TUNING_RULESET, + enum_values=['external', 'self'], + help='Which tuning ruleset to score this submission on. Can be external or self.', ) FLAGS = flags.FLAGS -MIN_INT = -2**(31) -MAX_INT = 2**(31) - 1 +MIN_INT = -(2 ** (31)) +MAX_INT = 2 ** (31) - 1 NUM_TUNING_TRIALS = 5 # For external tuning ruleset NUM_STUDIES = 3 WORKLOADS = { - "imagenet_resnet": {"dataset": "imagenet"}, - "imagenet_vit": {"dataset": "imagenet"}, - "fastmri": {"dataset": "fastmri"}, - "ogbg": {"dataset": "ogbg"}, - "wmt": {"dataset": "wmt"}, - "librispeech_deepspeech": {"dataset": "librispeech"}, - "criteo1tb": {"dataset": "criteo1tb"}, - "librispeech_conformer": {"dataset": "librispeech"} + 'imagenet_resnet': {'dataset': 'imagenet'}, + 'imagenet_vit': {'dataset': 'imagenet'}, + 'fastmri': {'dataset': 'fastmri'}, + 'ogbg': {'dataset': 'ogbg'}, + 'wmt': {'dataset': 'wmt'}, + 'librispeech_deepspeech': {'dataset': 'librispeech'}, + 'criteo1tb': {'dataset': 'criteo1tb'}, + 'librispeech_conformer': {'dataset': 'librispeech'}, } @@ -81,7 +87,7 @@ def main(_): print(seed) # Add job job = {} - study_dir = os.path.join(FLAGS.experiment_dir, f"study_{study_index}") + study_dir = os.path.join(FLAGS.experiment_dir, f'study_{study_index}') job['framework'] = FLAGS.framework job['workload'] = workload job['dataset'] = WORKLOADS[workload]['dataset'] @@ -103,7 +109,7 @@ def main(_): print(seed) # Add job job = {} - study_dir = os.path.join(FLAGS.experiment_dir, f"study_{study_index}") + study_dir = os.path.join(FLAGS.experiment_dir, f'study_{study_index}') job['framework'] = FLAGS.framework job['workload'] = workload job['dataset'] = WORKLOADS[workload]['dataset'] @@ -119,7 +125,7 @@ def main(_): # Convert job array to dict with job indices job_dict = {} for i, job in enumerate(jobs): - job_dict[f"{i}"] = job + job_dict[f'{i}'] = job with open('config.json', 'w') as f: json.dump(job_dict, f, indent=4) diff --git a/scoring/utils/slurm/run_jobs.sh b/scoring/utils/slurm/run_jobs.sh new file mode 100644 index 000000000..5fcf8f69e --- /dev/null +++ b/scoring/utils/slurm/run_jobs.sh @@ -0,0 +1,83 @@ +#!/bin/bash + +#SBATCH --nodes=1 # give it a full node +#SBATCH --ntasks-per-node=1 +#SBATCH --array= +#SBATCH --partition=v100 +#SBATCH --gpus-per-node=8 +#SBATCH --exclusive #this will not allow other jobs to run on this cluster +#SBATCH --output=experiments/tests/jit_debug_deepspeech_old_stephint_nadamw/job_%A_%a.out +#SBATCH --error=experiments/tests/jit_debug_deepspeech_old_stephint_nadamw/job_%A_%a.err + +# Usage: sbatch .sh +# This script reads config.json and launches a sbatch job using task +# arrays where each job in the array corresponds to a training run +# for a workload given a random seed and tuning trial index. +# To generate the config.json use make_job_config.py. 
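The header comment above explains that each SLURM array task resolves its own entry in `config.json` via `SLURM_ARRAY_TASK_ID` (the script below does this with `jq`). A Python sketch of the same lookup; the config path is an assumption, and the printed keys are taken from the `config.json` format shown earlier:

```python
# Sketch: map SLURM_ARRAY_TASK_ID to this task's job entry in the config
# generated by make_job_config.py. The path is an assumption for illustration.
import json
import os

task_id = os.environ.get('SLURM_ARRAY_TASK_ID', '0')
with open(os.path.expanduser('~/config.json')) as f:
  job = json.load(f)[task_id]

print(job['framework'], job['workload'], job['rng_seed'])
```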
+ +set -x + +# Pull docker image (ATTENTION: you may want to modify this) +REPO="" +IMAGE="" +y | gcloud auth configure-docker $REPO +docker pull $IMAGE +# Job config (ATTENTION: you may want to modify this) +config_file="" # Replace with your config file path +LOGS_BUCKET="" # replace with your bucket used for logging + + +# Function to read a JSON file and extract a value by key +read_json_value() { + local json_file="$1" + local index="$2" + local key="$3" + local value=$(jq -r ".[\"$index\"].$key" "$json_file") + echo "$value" +} + +# Check if jq is installed +if ! command -v jq &> /dev/null +then + echo "jq could not be found. Please install it." + exit 1 +fi + +TASK="$SLURM_ARRAY_TASK_ID" +FRAMEWORK=$(read_json_value "$config_file" "$TASK" "framework") +DATASET=$(read_json_value "$config_file" "$TASK" "dataset") +SUBMISSION_PATH=$(read_json_value "$config_file" "$TASK" "submission_path") +FRAMEWORK=$(read_json_value "$config_file" "$TASK" "framework") +TUNING_SEARCH_SPACE=$(read_json_value "$config_file" "$TASK" "tuning_search_space") +EXPERIMENT_DIR=$(read_json_value "$config_file" "$TASK" "experiment_dir") +MAX_STEPS=$(read_json_value "$config_file" "$TASK" "max_steps") +RNG_SEED=$(read_json_value "$config_file" "$TASK" "rng_seed") +WORKLOAD=$(read_json_value "$config_file" "$TASK" "workload") +HPARAM_START_INDEX=$(read_json_value "$config_file" "$TASK" "hparam_start_index") +HPARAM_END_INDEX=$(read_json_value "$config_file" "$TASK" "hparam_end_index") +NUM_TUNING_TRIALS=$(read_json_value "$config_file" "$TASK" "num_tuning_trials") +TUNING_RULESET=$(read_json_value "$config_file" "$TASK" "tuning_ruleset") +MAX_GLOBAL_STEPS=$(read_json_value "$config_file" "$MAX_GLOBAL_STEPS" "max_global_steps") + +docker run \ + -v /opt/data/:/data/ \ + -v $HOME/submissions_algorithms/:/algorithmic-efficiency/submissions_algorithms \ + --gpus all \ + --ipc=host \ + $IMAGE \ + -d $DATASET \ + -f $FRAMEWORK \ + -s $SUBMISSION_PATH \ + -w $WORKLOAD \ + -t $TUNING_SEARCH_SPACE \ + -e $EXPERIMENT_DIR \ + -c False \ + -o True \ + --rng_seed $RNG_SEED \ + --hparam_start_index $HPARAM_START_INDEX \ + --hparam_end_index $HPARAM_END_INDEX \ + --num_tuning_trials $NUM_TUNING_TRIALS \ + --tuning_ruleset $TUNING_RULESET \ + --logs_bucket $LOGS_BUCKET \ + -i true \ + -r false \ No newline at end of file diff --git a/scoring/utils/workload_metadata_external_tuning.json b/scoring/utils/workload_metadata_external_tuning.json index c205d28b2..c7d4ae195 100644 --- a/scoring/utils/workload_metadata_external_tuning.json +++ b/scoring/utils/workload_metadata_external_tuning.json @@ -1,34 +1,34 @@ { - "imagenet_resnet": { - "max_steps": 186666, - "dataset": "imagenet" - }, - "imagenet_vit": { - "max_steps": 186666, - "dataset": "imagenet" - }, - "fastmri": { - "max_steps": 36189, - "dataset": "fastmri" - }, - "ogbg": { - "max_steps": 80000, - "dataset": "ogbg" - }, - "wmt": { - "max_steps": 133333, - "dataset": "wmt" - }, - "librispeech_deepspeech": { - "max_steps": 48000, - "dataset": "librispeech" - }, - "criteo1tb": { - "max_steps": 10666, - "dataset": "criteo1tb" - }, - "librispeech_conformer": { - "max_steps": 80000, - "dataset": "librispeech" - } - } \ No newline at end of file + "imagenet_resnet": { + "max_steps": 186666, + "dataset": "imagenet" + }, + "imagenet_vit": { + "max_steps": 186666, + "dataset": "imagenet" + }, + "fastmri": { + "max_steps": 36189, + "dataset": "fastmri" + }, + "ogbg": { + "max_steps": 80000, + "dataset": "ogbg" + }, + "wmt": { + "max_steps": 133333, + "dataset": "wmt" + }, + 
"librispeech_deepspeech": { + "max_steps": 48000, + "dataset": "librispeech" + }, + "criteo1tb": { + "max_steps": 10666, + "dataset": "criteo1tb" + }, + "librispeech_conformer": { + "max_steps": 80000, + "dataset": "librispeech" + } +} diff --git a/scoring/utils/workload_metadata_self_tuning.json b/scoring/utils/workload_metadata_self_tuning.json index 105d5c52f..9d3e6b93d 100644 --- a/scoring/utils/workload_metadata_self_tuning.json +++ b/scoring/utils/workload_metadata_self_tuning.json @@ -1,34 +1,34 @@ { - "imagenet_resnet": { - "max_steps": 559998, - "dataset": "imagenet" - }, - "imagenet_vit": { - "max_steps": 559998, - "dataset": "imagenet" - }, - "fastmri": { - "max_steps": 108567, - "dataset": "fastmri" - }, - "ogbg": { - "max_steps": 240000, - "dataset": "ogbg" - }, - "wmt": { - "max_steps": 399999, - "dataset": "wmt" - }, - "librispeech_deepspeech": { - "max_steps": 144000, - "dataset": "librispeech" - }, - "criteo1tb": { - "max_steps": 31998, - "dataset": "criteo1tb" - }, - "librispeech_conformer": { - "max_steps": 240000, - "dataset": "librispeech" - } - } \ No newline at end of file + "imagenet_resnet": { + "max_steps": 559998, + "dataset": "imagenet" + }, + "imagenet_vit": { + "max_steps": 559998, + "dataset": "imagenet" + }, + "fastmri": { + "max_steps": 108567, + "dataset": "fastmri" + }, + "ogbg": { + "max_steps": 240000, + "dataset": "ogbg" + }, + "wmt": { + "max_steps": 399999, + "dataset": "wmt" + }, + "librispeech_deepspeech": { + "max_steps": 144000, + "dataset": "librispeech" + }, + "criteo1tb": { + "max_steps": 31998, + "dataset": "criteo1tb" + }, + "librispeech_conformer": { + "max_steps": 240000, + "dataset": "librispeech" + } +} diff --git a/submission_runner.py b/submission_runner.py index 468a04c7c..8dc8589cb 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -17,41 +17,37 @@ import datetime import gc import importlib -from inspect import signature import itertools import json import os import struct import time +from inspect import signature from types import MappingProxyType from typing import Any, Dict, Optional, Tuple -from absl import app -from absl import flags -from absl import logging import jax import tensorflow as tf import torch import torch.distributed as dist +from absl import app, flags, logging # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.set_visible_devices([], 'GPU') -from algoperf import checkpoint_utils -from algoperf import halton -from algoperf import logger_utils -from algoperf import random_utils as prng -from algoperf import spec -from algoperf.profiler import PassThroughProfiler -from algoperf.profiler import Profiler -from algoperf.pytorch_utils import pytorch_init -from algoperf.pytorch_utils import pytorch_setup -from algoperf.pytorch_utils import sync_ddp_time -from algoperf.workloads import workloads +from algoperf import checkpoint_utils, halton, logger_utils, spec # noqa: E402 +from algoperf import random_utils as prng # noqa: E402 +from algoperf.profiler import PassThroughProfiler, Profiler # noqa: E402 +from algoperf.pytorch_utils import ( # noqa: E402 + pytorch_init, + pytorch_setup, + sync_ddp_time, +) +from algoperf.workloads import workloads # noqa: E402 # Environment variables -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Disables tensorRT, cuda warnings. 
# disable only for deepspeech if it works fine for other workloads os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' @@ -62,106 +58,121 @@ WORKLOADS = workloads.WORKLOADS flags.DEFINE_string( - 'submission_path', - None, - 'The relative path of the Python file containing submission functions. ' - 'NOTE: the submission dir must have an __init__.py file!') + 'submission_path', + None, + 'The relative path of the Python file containing submission functions. ' + 'NOTE: the submission dir must have an __init__.py file!', +) flags.DEFINE_string( - 'workload', - None, - help=f'The name of the workload to run.\n Choices: {list(WORKLOADS.keys())}' + 'workload', + None, + help=f'The name of the workload to run.\n Choices: {list(WORKLOADS.keys())}', ) flags.DEFINE_enum( - 'tuning_ruleset', - 'external', - enum_values=['external', 'self'], - help='Which tuning ruleset to use.') + 'tuning_ruleset', + 'external', + enum_values=['external', 'self'], + help='Which tuning ruleset to use.', +) flags.DEFINE_string( - 'tuning_search_space', - None, - 'The path to the JSON file describing the external tuning search space.') -flags.DEFINE_integer('num_tuning_trials', - 1, - 'The number of external hyperparameter trials to run.') + 'tuning_search_space', + None, + 'The path to the JSON file describing the external tuning search space.', +) +flags.DEFINE_integer( + 'num_tuning_trials', 1, 'The number of external hyperparameter trials to run.' +) flags.DEFINE_string('data_dir', '~/data', 'Dataset location.') -flags.DEFINE_string('imagenet_v2_data_dir', - None, - 'Dataset location for ImageNet-v2.') -flags.DEFINE_string('librispeech_tokenizer_vocab_path', - '', - 'Location to librispeech tokenizer.') +flags.DEFINE_string( + 'imagenet_v2_data_dir', None, 'Dataset location for ImageNet-v2.' +) +flags.DEFINE_string( + 'librispeech_tokenizer_vocab_path', '', 'Location to librispeech tokenizer.' +) flags.DEFINE_enum( - 'framework', - None, - enum_values=['jax', 'pytorch'], - help='Whether to use Jax or Pytorch for the submission. Controls among ' - 'other things if the Jax or Numpy RNG library is used for RNG.') + 'framework', + None, + enum_values=['jax', 'pytorch'], + help='Whether to use Jax or Pytorch for the submission. Controls among ' + 'other things if the Jax or Numpy RNG library is used for RNG.', +) flags.DEFINE_boolean( - 'torch_compile', - True, - 'Whether to use `torch.compile` to JIT-compile PyTorch code. ' - 'This will only take effect when `framework`==pytorch.') + 'torch_compile', + True, + 'Whether to use `torch.compile` to JIT-compile PyTorch code. ' + 'This will only take effect when `framework`==pytorch.', +) flags.DEFINE_string( - 'experiment_dir', - None, - 'The root directory to store all experiments. ' - 'It is required and the directory should have ' - 'an absolute path rather than a relative path.') + 'experiment_dir', + None, + 'The root directory to store all experiments. ' + 'It is required and the directory should have ' + 'an absolute path rather than a relative path.', +) flags.DEFINE_string('experiment_name', None, 'Name of the experiment.') flags.DEFINE_boolean( - 'save_checkpoints', - True, - 'Whether or not to save checkpoints of the model and optimizer ' - 'at every eval and after training.') + 'save_checkpoints', + True, + 'Whether or not to save checkpoints of the model and optimizer ' + 'at every eval and after training.', +) +flags.DEFINE_boolean( + 'save_intermediate_checkpoints', + True, + 'Whether to save any intermediate checkpoints. 
' + 'If False, it will only keep the latest checkpoint.', +) flags.DEFINE_boolean( - 'save_intermediate_checkpoints', - True, - 'Whether to save any intermediate checkpoints. ' - 'If False, it will only keep the latest checkpoint.') -flags.DEFINE_boolean('resume_last_run', - None, - 'Whether to resume the experiment from its last run.') + 'resume_last_run', None, 'Whether to resume the experiment from its last run.' +) flags.DEFINE_boolean( - 'append_timestamp', - False, - 'If True, the current datetime will be appended to the experiment name. ' - 'Useful for guaranteeing a unique experiment dir for new runs.') -flags.DEFINE_boolean('use_wandb', - False, - 'Whether to use Weights & Biases logging.') + 'append_timestamp', + False, + 'If True, the current datetime will be appended to the experiment name. ' + 'Useful for guaranteeing a unique experiment dir for new runs.', +) +flags.DEFINE_boolean( + 'use_wandb', False, 'Whether to use Weights & Biases logging.' +) flags.DEFINE_boolean('profile', False, 'Whether to produce profiling output.') -flags.DEFINE_integer('max_global_steps', - None, - 'Maximum number of update steps.') +flags.DEFINE_integer( + 'max_global_steps', None, 'Maximum number of update steps.' +) flags.DEFINE_boolean( - 'overwrite', - False, - 'Whether to overwrite the experiment with identical experiment_dir and' - 'experiment_name.') + 'overwrite', + False, + 'Whether to overwrite the experiment with identical experiment_dir and' + 'experiment_name.', +) flags.DEFINE_integer( - 'hparam_start_index', - None, - 'Start index to slice set of hyperparameters in tuning search space.') + 'hparam_start_index', + None, + 'Start index to slice set of hyperparameters in tuning search space.', +) flags.DEFINE_integer( - 'hparam_end_index', - None, - 'End index to slice set of hyperparameters in tuning search space.') + 'hparam_end_index', + None, + 'End index to slice set of hyperparameters in tuning search space.', +) flags.DEFINE_integer( - 'rng_seed', - None, - 'Value of rng seed. If None, a random seed will' - 'be generated from hardware.') -flags.DEFINE_boolean('set_pytorch_max_split_size', - False, - 'If true, set pytorch max_split_size_mb to 256') + 'rng_seed', + None, + 'Value of rng seed. If None, a random seed willbe generated from hardware.', +) +flags.DEFINE_boolean( + 'set_pytorch_max_split_size', + False, + 'If true, set pytorch max_split_size_mb to 256', +) flags.DEFINE_integer( - 'pytorch_eval_num_workers', - 0, - 'Number of workers for ImageNet PyTorch evaluation data loaders.' - 'WARNING: Setting pytorch_eval_num_workers != 0, will result ' - 'in incorrect evals currently, see issues/732.') + 'pytorch_eval_num_workers', + 0, + 'Number of workers for ImageNet PyTorch evaluation data loaders.' 
+ 'WARNING: Setting pytorch_eval_num_workers != 0, will result ' + 'in incorrect evals currently, see issues/732.', +) FLAGS = flags.FLAGS USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() @@ -193,23 +204,23 @@ def _reset_cuda_mem(): def train_once( - workload: spec.Workload, - workload_name: str, - global_batch_size: int, - global_eval_batch_size: int, - data_dir: str, - imagenet_v2_data_dir: str, - init_optimizer_state: spec.InitOptimizerFn, - update_params: spec.UpdateParamsFn, - data_selection: spec.DataSelectionFn, - prepare_for_eval: Optional[spec.PrepareForEvalFn], - hyperparameters: Optional[spec.Hyperparameters], - rng_seed: int, - rng: spec.RandomState, - profiler: Profiler, - max_global_steps: int = None, - log_dir: Optional[str] = None, - save_checkpoints: Optional[bool] = True + workload: spec.Workload, + workload_name: str, + global_batch_size: int, + global_eval_batch_size: int, + data_dir: str, + imagenet_v2_data_dir: str, + init_optimizer_state: spec.InitOptimizerFn, + update_params: spec.UpdateParamsFn, + data_selection: spec.DataSelectionFn, + prepare_for_eval: Optional[spec.PrepareForEvalFn], + hyperparameters: Optional[spec.Hyperparameters], + rng_seed: int, + rng: spec.RandomState, + profiler: Profiler, + max_global_steps: int = None, + log_dir: Optional[str] = None, + save_checkpoints: Optional[bool] = True, ) -> Tuple[spec.Timing, Dict[str, Any]]: _reset_cuda_mem() data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) @@ -222,47 +233,44 @@ def train_once( workload.eval_num_workers = FLAGS.pytorch_eval_num_workers with profiler.profile('Initializing dataset'): input_queue = workload._build_input_queue( - data_rng, - 'train', - data_dir=data_dir, - global_batch_size=global_batch_size) + data_rng, 'train', data_dir=data_dir, global_batch_size=global_batch_size + ) logging.info('Initializing model.') with profiler.profile('Initializing model'): - dropout_rate = None - aux_dropout_rate = None - if hasattr(hyperparameters, 'dropout_rate'): - dropout_rate = hyperparameters.dropout_rate - if hasattr(hyperparameters, 'aux_dropout_rate'): - aux_dropout_rate = hyperparameters.aux_dropout_rate - model_params, model_state = workload.init_model_fn( - model_init_rng, dropout_rate, aux_dropout_rate) + model_params, model_state = workload.init_model_fn(model_init_rng) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ - 'librispeech_conformer', - 'ogbg', - 'criteo1tb', - 'imagenet_vit', - 'librispeech_deepspeech' + 'librispeech_conformer', + 'ogbg', + 'criteo1tb', + 'imagenet_vit', + 'librispeech_deepspeech', ] eager_backend_workloads = [] aot_eager_backend_workloads = [] loss_compilation_workloads = [ - 'fastmri', 'librispeech_deepspeech', 'ogbg', 'wmt' + 'fastmri', + 'librispeech_deepspeech', + 'ogbg', + 'wmt', ] base_workload = workloads.get_base_workload_name(workload_name) if base_workload in compile_error_workloads: logging.warning( - 'These workloads cannot be fully compiled under current ' - 'PyTorch version. Proceeding without `torch.compile`.') + 'These workloads cannot be fully compiled under current ' + 'PyTorch version. Proceeding without `torch.compile`.' + ) elif base_workload in eager_backend_workloads: logging.warning( - 'These workloads cannot be fully compiled under current ' - 'PyTorch version. Proceeding with `backend=eager`.') + 'These workloads cannot be fully compiled under current ' + 'PyTorch version. Proceeding with `backend=eager`.' 
+ ) model_params = torch.compile(model_params, backend='eager') elif base_workload in aot_eager_backend_workloads: logging.warning( - 'These workloads cannot be fully compiled under current ' - 'PyTorch version. Proceeding with `backend=aot_eager`.') + 'These workloads cannot be fully compiled under current ' + 'PyTorch version. Proceeding with `backend=aot_eager`.' + ) model_params = torch.compile(model_params, backend='aot_eager') else: logging.info('Performing `torch.compile`.') @@ -271,11 +279,9 @@ def train_once( workload.loss_fn = torch.compile(workload.loss_fn) logging.info('Initializing optimizer.') with profiler.profile('Initializing optimizer'): - optimizer_state = init_optimizer_state(workload, - model_params, - model_state, - hyperparameters, - opt_init_rng) + optimizer_state = init_optimizer_state( + workload, model_params, model_state, hyperparameters, opt_init_rng + ) logging.info('Initializing metrics bundle.') # Check if 'train_state' is in the function signature @@ -283,15 +289,15 @@ def train_once( # Bookkeeping. train_state = { - 'validation_goal_reached': False, - 'test_goal_reached': False, - 'is_time_remaining': True, - 'last_eval_time': 0, - 'training_complete': False, - 'accumulated_submission_time': 0, - 'accumulated_eval_time': 0, - 'accumulated_logging_time': 0, - 'last_step_end_time': None, + 'validation_goal_reached': False, + 'test_goal_reached': False, + 'is_time_remaining': True, + 'last_eval_time': 0, + 'training_complete': False, + 'accumulated_submission_time': 0, + 'accumulated_eval_time': 0, + 'accumulated_logging_time': 0, + 'last_step_end_time': None, } global_step = 0 eval_results = [] @@ -301,22 +307,25 @@ def train_once( logging.info('Initializing checkpoint and logger.') if log_dir is not None: # If the checkpoint exists, load from the checkpoint. 
- (optimizer_state, - model_params, - model_state, - train_state, - eval_results, - global_step, - preemption_count) = checkpoint_utils.maybe_restore_checkpoint( - FLAGS.framework, - optimizer_state, - model_params, - model_state, - train_state, - eval_results, - global_step, - preemption_count, - checkpoint_dir=log_dir) + ( + optimizer_state, + model_params, + model_state, + train_state, + eval_results, + global_step, + preemption_count, + ) = checkpoint_utils.maybe_restore_checkpoint( + FLAGS.framework, + optimizer_state, + model_params, + model_state, + train_state, + eval_results, + global_step, + preemption_count, + checkpoint_dir=log_dir, + ) meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') logging.info(f'Saving meta data to {meta_file_name}.') meta_data = logger_utils.get_meta_data(workload, rng_seed) @@ -326,9 +335,9 @@ def train_once( logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict()) metrics_logger = None if RANK == 0: - metrics_logger = logger_utils.set_up_loggers(log_dir, - flags.FLAGS, - hyperparameters) + metrics_logger = logger_utils.set_up_loggers( + log_dir, flags.FLAGS, hyperparameters + ) workload.attach_metrics_logger(metrics_logger) global_start_time = get_time() @@ -336,42 +345,50 @@ def train_once( logging.info('Starting training loop.') goals_reached = ( - train_state['validation_goal_reached'] and - train_state['test_goal_reached']) - while train_state['is_time_remaining'] and \ - not goals_reached and \ - not train_state['training_complete']: - + train_state['validation_goal_reached'] and train_state['test_goal_reached'] + ) + while ( + train_state['is_time_remaining'] + and not goals_reached + and not train_state['training_complete'] + ): step_rng = prng.fold_in(rng, global_step) - data_select_rng, update_rng, prep_eval_rng, eval_rng = \ - prng.split(step_rng, 4) + data_select_rng, update_rng, prep_eval_rng, eval_rng = prng.split( + step_rng, 4 + ) with profiler.profile('Data selection'): - batch = data_selection(workload, - input_queue, - optimizer_state, - model_params, - model_state, - hyperparameters, - global_step, - data_select_rng) + batch = data_selection( + workload, + input_queue, + optimizer_state, + model_params, + model_state, + hyperparameters, + global_step, + data_select_rng, + ) try: with profiler.profile('Update parameters'): optimizer_state, model_params, model_state = update_params( - workload=workload, - current_param_container=model_params, - current_params_types=workload.model_params_types, - model_state=model_state, - hyperparameters=hyperparameters, - batch=batch, - loss_type=workload.loss_type, - optimizer_state=optimizer_state, - eval_results=eval_results, - global_step=global_step, - rng=update_rng, - **({'train_state': MappingProxyType(train_state)} - if needs_train_state else {})) + workload=workload, + current_param_container=model_params, + current_params_types=workload.model_params_types, + model_state=model_state, + hyperparameters=hyperparameters, + batch=batch, + loss_type=workload.loss_type, + optimizer_state=optimizer_state, + eval_results=eval_results, + global_step=global_step, + rng=update_rng, + **( + {'train_state': MappingProxyType(train_state)} + if needs_train_state + else {} + ), + ) except spec.TrainingCompleteError: train_state['training_complete'] = True global_step += 1 @@ -381,121 +398,139 @@ def train_once( train_step_end_time = get_time() train_state['accumulated_submission_time'] += ( - train_step_end_time - train_state['last_step_end_time']) + 
train_step_end_time - train_state['last_step_end_time'] + ) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): - + if ( + train_step_end_time - train_state['last_eval_time'] + ) >= workload.eval_period_time_sec or train_state['training_complete']: # Prepare for evaluation (timed). if prepare_for_eval is not None: - with profiler.profile('Prepare for eval'): del batch prepare_for_eval_start_time = get_time() optimizer_state, model_params, model_state = prepare_for_eval( - workload=workload, - current_param_container=model_params, - current_params_types=workload.model_params_types, - model_state=model_state, - hyperparameters=hyperparameters, - loss_type=workload.loss_type, - optimizer_state=optimizer_state, - eval_results=eval_results, - global_step=global_step, - rng=prep_eval_rng) + workload=workload, + current_param_container=model_params, + current_params_types=workload.model_params_types, + model_state=model_state, + hyperparameters=hyperparameters, + loss_type=workload.loss_type, + optimizer_state=optimizer_state, + eval_results=eval_results, + global_step=global_step, + rng=prep_eval_rng, + ) prepare_for_eval_end_time = get_time() # Update sumbission time. train_state['accumulated_submission_time'] += ( - prepare_for_eval_end_time - prepare_for_eval_start_time) + prepare_for_eval_end_time - prepare_for_eval_start_time + ) # Check if time is remaining, # use 1.5x the runtime budget for the self-tuning ruleset. max_allowed_runtime_sec = ( - workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external' - else 1.5 * workload.max_allowed_runtime_sec) + workload.max_allowed_runtime_sec + if FLAGS.tuning_ruleset == 'external' + else 1.5 * workload.max_allowed_runtime_sec + ) train_state['is_time_remaining'] = ( - train_state['accumulated_submission_time'] < max_allowed_runtime_sec) + train_state['accumulated_submission_time'] < max_allowed_runtime_sec + ) # Eval if time is remaining (untimed). if train_state['is_time_remaining']: - with profiler.profile('Evaluation'): _reset_cuda_mem() try: eval_start_time = get_time() - latest_eval_result = workload.eval_model(global_eval_batch_size, - model_params, - model_state, - eval_rng, - data_dir, - imagenet_v2_data_dir, - global_step) + latest_eval_result = workload.eval_model( + global_eval_batch_size, + model_params, + model_state, + eval_rng, + data_dir, + imagenet_v2_data_dir, + global_step, + ) # Check if targets reached. # Note that this is one of the stopping conditions for the length of # a training run. To score the run we only consider the time # to validation target retrospectively. train_state['validation_goal_reached'] = ( - workload.has_reached_validation_target(latest_eval_result) or - train_state['validation_goal_reached']) + workload.has_reached_validation_target(latest_eval_result) + or train_state['validation_goal_reached'] + ) train_state['test_goal_reached'] = ( - workload.has_reached_test_target(latest_eval_result) or - train_state['test_goal_reached']) + workload.has_reached_test_target(latest_eval_result) + or train_state['test_goal_reached'] + ) goals_reached = ( - train_state['validation_goal_reached'] and - train_state['test_goal_reached']) + train_state['validation_goal_reached'] + and train_state['test_goal_reached'] + ) # Save last eval time. eval_end_time = get_time() train_state['last_eval_time'] = eval_end_time # Accumulate eval time. 
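The bookkeeping above only counts time spent inside the submission toward the runtime budget: eval and logging time are accumulated separately, and the self-tuning ruleset gets 1.5x the per-workload budget. A stripped-down sketch of that accounting, not the benchmark's actual implementation; the budget, sleeps, and loop are placeholders:

```python
# Stripped-down sketch of the train_once timing bookkeeping: submission time
# accrues between step boundaries, eval time is tracked separately, and the
# self-tuning ruleset gets 1.5x the budget. All numbers here are placeholders.
import time

max_allowed_runtime_sec = 1.0
tuning_ruleset = 'self'  # or 'external'
budget = max_allowed_runtime_sec * (1.5 if tuning_ruleset == 'self' else 1.0)

state = {'accumulated_submission_time': 0.0, 'accumulated_eval_time': 0.0}
last_step_end = time.monotonic()

while state['accumulated_submission_time'] < budget:
  time.sleep(0.01)  # stand-in for one update_params call
  now = time.monotonic()
  state['accumulated_submission_time'] += now - last_step_end

  eval_start = time.monotonic()  # evals do not count toward the budget
  time.sleep(0.001)  # stand-in for workload.eval_model
  state['accumulated_eval_time'] += time.monotonic() - eval_start

  last_step_end = time.monotonic()

print(state)
```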
- train_state[ - 'accumulated_eval_time'] += eval_end_time - eval_start_time + train_state['accumulated_eval_time'] += ( + eval_end_time - eval_start_time + ) # Add times to eval results for logging. - latest_eval_result['score'] = ( - train_state['accumulated_submission_time']) - latest_eval_result[ - 'total_duration'] = eval_end_time - global_start_time + latest_eval_result['score'] = train_state[ + 'accumulated_submission_time' + ] + latest_eval_result['total_duration'] = ( + eval_end_time - global_start_time + ) latest_eval_result['accumulated_submission_time'] = train_state[ - 'accumulated_submission_time'] + 'accumulated_submission_time' + ] latest_eval_result['accumulated_eval_time'] = train_state[ - 'accumulated_eval_time'] + 'accumulated_eval_time' + ] latest_eval_result['accumulated_logging_time'] = train_state[ - 'accumulated_logging_time'] + 'accumulated_logging_time' + ] time_since_start = latest_eval_result['total_duration'] - logging.info(f'Time since start: {time_since_start:.2f}s, ' - f'\tStep: {global_step}, \t{latest_eval_result}') + logging.info( + f'Time since start: {time_since_start:.2f}s, ' + f'\tStep: {global_step}, \t{latest_eval_result}' + ) eval_results.append((global_step, latest_eval_result)) logging_start_time = get_time() if log_dir is not None and RANK == 0: metrics_logger.append_scalar_metrics( - latest_eval_result, - global_step=global_step, - preemption_count=preemption_count, - is_eval=True, + latest_eval_result, + global_step=global_step, + preemption_count=preemption_count, + is_eval=True, ) if save_checkpoints: checkpoint_utils.save_checkpoint( - framework=FLAGS.framework, - optimizer_state=optimizer_state, - model_params=model_params, - model_state=model_state, - train_state=train_state, - eval_results=eval_results, - global_step=global_step, - preemption_count=preemption_count, - checkpoint_dir=log_dir, - save_intermediate_checkpoints=FLAGS - .save_intermediate_checkpoints) + framework=FLAGS.framework, + optimizer_state=optimizer_state, + model_params=model_params, + model_state=model_state, + train_state=train_state, + eval_results=eval_results, + global_step=global_step, + preemption_count=preemption_count, + checkpoint_dir=log_dir, + save_intermediate_checkpoints=FLAGS.save_intermediate_checkpoints, + ) logging_end_time = get_time() train_state['accumulated_logging_time'] += ( - logging_end_time - logging_start_time) + logging_end_time - logging_start_time + ) _reset_cuda_mem() @@ -503,8 +538,9 @@ def train_once( logging.exception(f'Eval step {global_step} error.\n') if 'out of memory' in str(e): logging.warning( - 'Error: GPU out of memory during eval during step ' - f'{global_step}, error : {str(e)}.') + 'Error: GPU out of memory during eval during step ' + f'{global_step}, error : {str(e)}.' 
+ ) _reset_cuda_mem() train_state['last_step_end_time'] = get_time() @@ -513,41 +549,45 @@ def train_once( if log_dir is not None and RANK == 0: metrics_logger.append_scalar_metrics( - {'score': train_state['accumulated_submission_time']}, - global_step=global_step, - preemption_count=preemption_count) + {'score': train_state['accumulated_submission_time']}, + global_step=global_step, + preemption_count=preemption_count, + ) metrics_logger.finish() if save_checkpoints: checkpoint_utils.save_checkpoint( - framework=FLAGS.framework, - optimizer_state=optimizer_state, - model_params=model_params, - model_state=model_state, - train_state=train_state, - eval_results=eval_results, - global_step=global_step, - preemption_count=preemption_count, - checkpoint_dir=log_dir, - save_intermediate_checkpoints=FLAGS.save_intermediate_checkpoints) + framework=FLAGS.framework, + optimizer_state=optimizer_state, + model_params=model_params, + model_state=model_state, + train_state=train_state, + eval_results=eval_results, + global_step=global_step, + preemption_count=preemption_count, + checkpoint_dir=log_dir, + save_intermediate_checkpoints=FLAGS.save_intermediate_checkpoints, + ) return train_state['accumulated_submission_time'], metrics -def score_submission_on_workload(workload: spec.Workload, - workload_name: str, - submission_path: str, - data_dir: str, - tuning_ruleset: str, - profiler: Optional[Profiler] = None, - max_global_steps: Optional[int] = None, - imagenet_v2_data_dir: Optional[str] = None, - tuning_search_space: Optional[str] = None, - num_tuning_trials: Optional[int] = None, - log_dir: Optional[str] = None, - save_checkpoints: Optional[bool] = True, - hparam_start_index: Optional[bool] = None, - hparam_end_index: Optional[bool] = None, - rng_seed: Optional[int] = None): +def score_submission_on_workload( + workload: spec.Workload, + workload_name: str, + submission_path: str, + data_dir: str, + tuning_ruleset: str, + profiler: Optional[Profiler] = None, + max_global_steps: Optional[int] = None, + imagenet_v2_data_dir: Optional[str] = None, + tuning_search_space: Optional[str] = None, + num_tuning_trials: Optional[int] = None, + log_dir: Optional[str] = None, + save_checkpoints: Optional[bool] = True, + hparam_start_index: Optional[bool] = None, + hparam_end_index: Optional[bool] = None, + rng_seed: Optional[int] = None, +): # Expand paths because '~' may not be recognized data_dir = os.path.expanduser(data_dir) if imagenet_v2_data_dir: @@ -571,18 +611,21 @@ def score_submission_on_workload(workload: spec.Workload, n_gpus = max(N_GPUS, jax.local_device_count()) if global_batch_size % n_gpus != 0: raise ValueError( - f'The global batch size ({global_batch_size}) has to be divisible by ' - f'the number of GPUs ({n_gpus}).') + f'The global batch size ({global_batch_size}) has to be divisible by ' + f'the number of GPUs ({n_gpus}).' + ) if hasattr(submission_module, 'get_eval_batch_size'): # If the user specifies the eval batch size, use the provided one. global_eval_batch_size = submission_module.get_eval_batch_size( - workload_name) + workload_name + ) else: global_eval_batch_size = workload.eval_batch_size if global_eval_batch_size % n_gpus != 0: raise ValueError( - f'The global eval batch size ({global_eval_batch_size}) has to be ' - f'divisible by the number of GPUs ({n_gpus}).') + f'The global eval batch size ({global_eval_batch_size}) has to be ' + f'divisible by the number of GPUs ({n_gpus}).' 
+ ) if tuning_ruleset == 'external': # If the submission runner is responsible for hyperparameter tuning, load in @@ -590,15 +633,18 @@ def score_submission_on_workload(workload: spec.Workload, # settings from it. if tuning_search_space is None: raise ValueError( - 'Must provide a tuning search space JSON file when using external ' - 'tuning.') + 'Must provide a tuning search space JSON file when using external ' + 'tuning.' + ) with open(tuning_search_space, 'r', encoding='UTF-8') as search_space_file: tuning_search_space = halton.generate_search( - json.load(search_space_file), num_tuning_trials) + json.load(search_space_file), num_tuning_trials + ) all_timings = {} all_metrics = {} tuning_search_space_iter = itertools.islice( - enumerate(tuning_search_space), hparam_start_index, hparam_end_index) + enumerate(tuning_search_space), hparam_start_index, hparam_end_index + ) for hi, hyperparameters in tuning_search_space_iter: # Generate a new seed from hardware sources of randomness for each trial. if not rng_seed: @@ -622,25 +668,31 @@ def score_submission_on_workload(workload: spec.Workload, # If existing hyperparameter exists, use saved # hyperparameters for consistency. - hyperparameters = logger_utils.write_hparams(hyperparameters, - tuning_dir_name) + hyperparameters = logger_utils.write_hparams( + hyperparameters, tuning_dir_name + ) tuning_search_space[hi] = hyperparameters with profiler.profile('Train'): - timing, metrics = train_once(workload, workload_name, - global_batch_size, - global_eval_batch_size, - data_dir, imagenet_v2_data_dir, - init_optimizer_state, - update_params, data_selection, - prepare_for_eval, - hyperparameters, - rng_seed, - rng, - profiler, - max_global_steps, - tuning_dir_name, - save_checkpoints=save_checkpoints,) + timing, metrics = train_once( + workload, + workload_name, + global_batch_size, + global_eval_batch_size, + data_dir, + imagenet_v2_data_dir, + init_optimizer_state, + update_params, + data_selection, + prepare_for_eval, + hyperparameters, + rng_seed, + rng, + profiler, + max_global_steps, + tuning_dir_name, + save_checkpoints=save_checkpoints, + ) all_timings[hi] = timing all_metrics[hi] = metrics logging.info(f'Tuning trial {hi + 1}/{num_tuning_trials}') @@ -654,7 +706,8 @@ def score_submission_on_workload(workload: spec.Workload, else: if tuning_search_space is not None: raise ValueError( - 'Cannot provide a tuning search space when using self tuning.') + 'Cannot provide a tuning search space when using self tuning.' + ) if not rng_seed: rng_seed = struct.unpack('q', os.urandom(8))[0] rng = prng.PRNGKey(rng_seed) @@ -666,11 +719,24 @@ def score_submission_on_workload(workload: spec.Workload, logger_utils.makedir(log_dir) with profiler.profile('Train'): score, _ = train_once( - workload, workload_name, global_batch_size, global_eval_batch_size, - data_dir, imagenet_v2_data_dir, - init_optimizer_state, update_params, data_selection, prepare_for_eval, - None, rng_seed, rng, profiler, max_global_steps, log_dir, - save_checkpoints=save_checkpoints) + workload, + workload_name, + global_batch_size, + global_eval_batch_size, + data_dir, + imagenet_v2_data_dir, + init_optimizer_state, + update_params, + data_selection, + prepare_for_eval, + None, + rng_seed, + rng, + profiler, + max_global_steps, + log_dir, + save_checkpoints=save_checkpoints, + ) return score @@ -694,59 +760,66 @@ def main(_): # TODO: remove once issue resolved. 
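In the external-tuning branch above, the generated search space is sliced with `itertools.islice(enumerate(...), hparam_start_index, hparam_end_index)`; this is what lets the SLURM job config shard tuning trials across array tasks. A self-contained sketch of that slicing, with made-up trial dictionaries:

```python
# Sketch: shard tuning trials across jobs via hparam_start_index /
# hparam_end_index, mirroring the islice call above. Trial values are made up.
import itertools

tuning_search_space = [{'learning_rate': 10**-i} for i in range(5)]
hparam_start_index, hparam_end_index = 2, 4  # this job runs trials 2 and 3

for hi, hyperparameters in itertools.islice(
  enumerate(tuning_search_space), hparam_start_index, hparam_end_index
):
  print(hi, hyperparameters)
```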
if FLAGS.pytorch_eval_num_workers != 0: logging.warning( - 'WARNING: Setting pytorch_eval_num_workers != 0, will result ' - 'in incorrect evals currently, see issues/732.') + 'WARNING: Setting pytorch_eval_num_workers != 0, will result ' + 'in incorrect evals currently, see issues/732.' + ) workload_metadata = WORKLOADS[FLAGS.workload] if base_workload in [ - 'librispeech_conformer', - 'librispeech_deepspeech', - 'imagenet_vit', - 'criteo1tb' + 'librispeech_conformer', + 'librispeech_deepspeech', + 'imagenet_vit', + 'criteo1tb', ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' # Extend path according to framework. workload_metadata['workload_path'] = os.path.join( - BASE_WORKLOADS_DIR, - workload_metadata['workload_path'] + f'_{FLAGS.framework}', - 'workload.py') + BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + f'_{FLAGS.framework}', + 'workload.py', + ) workload_init_kwargs = {} if FLAGS.librispeech_tokenizer_vocab_path: workload_init_kwargs['tokenizer_vocab_path'] = ( - FLAGS.librispeech_tokenizer_vocab_path) + FLAGS.librispeech_tokenizer_vocab_path + ) workload = workloads.import_workload( - workload_path=workload_metadata['workload_path'], - workload_class_name=workload_metadata['workload_class_name'], - workload_init_kwargs=workload_init_kwargs) + workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + workload_init_kwargs=workload_init_kwargs, + ) experiment_name = FLAGS.experiment_name if experiment_name and FLAGS.append_timestamp: experiment_name += datetime.datetime.now().strftime('-%Y-%m-%d-%H-%M-%S') - logging_dir_path = logger_utils.get_log_dir(FLAGS.experiment_dir, - FLAGS.workload, - FLAGS.framework, - experiment_name, - FLAGS.resume_last_run, - FLAGS.overwrite) + logging_dir_path = logger_utils.get_log_dir( + FLAGS.experiment_dir, + FLAGS.workload, + FLAGS.framework, + experiment_name, + FLAGS.resume_last_run, + FLAGS.overwrite, + ) score = score_submission_on_workload( - workload=workload, - workload_name=FLAGS.workload, - submission_path=FLAGS.submission_path, - data_dir=FLAGS.data_dir, - tuning_ruleset=FLAGS.tuning_ruleset, - profiler=profiler, - max_global_steps=FLAGS.max_global_steps, - imagenet_v2_data_dir=FLAGS.imagenet_v2_data_dir, - tuning_search_space=FLAGS.tuning_search_space, - num_tuning_trials=FLAGS.num_tuning_trials, - log_dir=logging_dir_path, - save_checkpoints=FLAGS.save_checkpoints, - hparam_start_index=FLAGS.hparam_start_index, - hparam_end_index=FLAGS.hparam_end_index, - rng_seed=FLAGS.rng_seed) + workload=workload, + workload_name=FLAGS.workload, + submission_path=FLAGS.submission_path, + data_dir=FLAGS.data_dir, + tuning_ruleset=FLAGS.tuning_ruleset, + profiler=profiler, + max_global_steps=FLAGS.max_global_steps, + imagenet_v2_data_dir=FLAGS.imagenet_v2_data_dir, + tuning_search_space=FLAGS.tuning_search_space, + num_tuning_trials=FLAGS.num_tuning_trials, + log_dir=logging_dir_path, + save_checkpoints=FLAGS.save_checkpoints, + hparam_start_index=FLAGS.hparam_start_index, + hparam_end_index=FLAGS.hparam_end_index, + rng_seed=FLAGS.rng_seed, + ) logging.info(f'Final {FLAGS.workload} score: {score}') if FLAGS.profile: diff --git a/submissions/submission_checker.py b/submissions/submission_checker.py index ab657c0f0..f8af9fb52 100644 --- a/submissions/submission_checker.py +++ b/submissions/submission_checker.py @@ -29,7 +29,6 @@ import argparse import logging import os -import subprocess SELF_TUNING = 'self_tuning' EXTERNAL_TUNING = 'external_tuning' @@ -41,7 +40,8 @@ def 
_check_ruleset_subdirs(submission_dir): contents = os.listdir(submission_dir) if not ((EXTERNAL_TUNING in contents) or (SELF_TUNING in contents)): logging.info( - f'CHECK FAILED: {submission_dir} does not contain ruleset subdir.') + f'CHECK FAILED: {submission_dir} does not contain ruleset subdir.' + ) return False return True @@ -54,7 +54,7 @@ def _check_submission_module(submission_dir): contents = os.listdir(os.path.join(root, submission_dir)) if SUBMISSION_MODULE not in contents: logging.info( - f'CHECK FAILED: {parent_dir}/{submission_dir} does not contain {SUBMISSION_MODULE}' + f'CHECK FAILED: {parent_dir}/{submission_dir} does not contain {SUBMISSION_MODULE}' ) return False return True @@ -68,7 +68,7 @@ def _check_tuning_search_space_file(submission_dir): contents = os.listdir(os.path.join(root, submission_dir)) if TUNING_SEARCH_SPACE_FILENAME not in contents: logging.info( - f'CHECK FAILED: {parent_dir}/{submission_dir} does not contain {TUNING_SEARCH_SPACE_FILENAME}' + f'CHECK FAILED: {parent_dir}/{submission_dir} does not contain {TUNING_SEARCH_SPACE_FILENAME}' ) return False return True @@ -76,18 +76,22 @@ def _check_tuning_search_space_file(submission_dir): def run_checks(submission_dir): """Top-level checker function. - Call individual checkers from this function. - """ + Call individual checkers from this function. + """ logging.info('Running repository checks.') # Execute checks contains_ruleset_subdirs = _check_ruleset_subdirs(submission_dir) contains_submission_module = _check_submission_module(submission_dir) contains_tuning_search_space_file = _check_tuning_search_space_file( - submission_dir) + submission_dir + ) - if not (contains_ruleset_subdirs and contains_submission_module and - contains_tuning_search_space_file): + if not ( + contains_ruleset_subdirs + and contains_submission_module + and contains_tuning_search_space_file + ): logging.info('TESTS FAILED.') return False @@ -98,16 +102,17 @@ def run_checks(submission_dir): def get_parser(): """Parse commandline.""" parser = argparse.ArgumentParser( - description='Checks for submission folder for AlgoPerf',) + description='Checks for submission folder for AlgoPerf', + ) parser.add_argument( - 'folder', - type=str, - help='the folder for a submission package.', + 'folder', + type=str, + help='the folder for a submission package.', ) parser.add_argument( - '--log_output', - type=str, - default='submission_checker.log', + '--log_output', + type=str, + default='submission_checker.log', ) return parser @@ -118,7 +123,7 @@ def main(): logging.basicConfig(filename=args.log_output, level=logging.INFO) logging.getLogger().addHandler(logging.StreamHandler()) - formatter = logging.Formatter("%(levelname)s - %(message)s") + formatter = logging.Formatter('%(levelname)s - %(message)s') logging.getLogger().handlers[0].setFormatter(formatter) logging.getLogger().handlers[1].setFormatter(formatter) diff --git a/submissions/template/submission.py b/submissions/template/submission.py index a4fdc62b4..db6900afd 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -4,96 +4,102 @@ and https://github.com/mlcommons/algorithmic-efficiency/blob/main/DOCUMENTATION.md#disallowed-submissions for guidelines. 
""" + from typing import Any, Dict, Iterator, List, Optional, Tuple from algoperf import spec -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Nesterov optimizer and a learning rate schedule. - Returns: - optimizer state - optimizer_update_fn - """ + Returns: spec.OptimizerState initialized optimizer state + """ pass def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: + """ + Returns: + spec.OptimizerState: new optimizer state + spec.ParameterTypeTree: new params + new_model_state: new model state """ - Returns: - (new_optimizer_state, update_fn) - new_params - new_model_state - """ pass -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: + """ + Returns: + new_optimizer_state + new_params + new_model_state """ - Returns: - new_optimizer_state - new_params - new_model_state - """ pass def get_batch_size(workload_name): """ - Gets batch size for workload. - Note that these batch sizes only apply during training and not during evals. - Args: - workload_name (str): Valid workload_name values are: "wmt", "ogbg", - "criteo1tb", "fastmri", "imagenet_resnet", "imagenet_vit", - "librispeech_deepspeech", "librispeech_conformer" or any of the - variants. - Returns: - int: batch_size - Raises: - ValueError: If workload_name is not handled. - """ + Gets batch size for workload. + Note that these batch sizes only apply during training and not during evals. 
+ Args: + workload_name (str): Valid workload_name values are: "wmt", "ogbg", + "criteo1tb", "fastmri", "imagenet_resnet", "imagenet_vit", + "librispeech_deepspeech", "librispeech_conformer" or any of the + variants. + Returns: + int: batch_size + Raises: + ValueError: If workload_name is not handled. + """ pass -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - Tip: - If you would just like the next batch from the input queue return next(input_queue). + Each element of the queue is a batch of training examples and labels. + Tip: + If you would just like the next batch from the input queue return next(input_queue). - Returns: - batch: next batch of input data - """ + Returns: + batch: next batch of input data + """ pass diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index 9de61a2a5..48f658d06 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -8,10 +8,12 @@ import torch from algoperf import spec -from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \ - Criteo1TbDlrmSmallWorkload as JaxWorkload -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallWorkload as PyTorchWorkload +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import ( + Criteo1TbDlrmSmallWorkload as JaxWorkload, +) +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import ( + Criteo1TbDlrmSmallWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -54,31 +56,34 @@ def sd_transform(sd): pytorch_workload = PyTorchWorkload() pytorch_batch = { - 'inputs': torch.ones((2, 13 + 26)), - 'targets': torch.randint(low=0, high=1, size=(2,)), - 'weights': torch.ones(2), + 'inputs': torch.ones((2, 13 + 26)), + 'targets': torch.randint(low=0, high=1, size=(2,)), + 'weights': torch.ones(2), } jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} # Test outputs for identical weights and inputs. 
pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/criteo1tb_embed_init/compare.py b/tests/modeldiffs/criteo1tb_embed_init/compare.py index f1897d16f..897920bac 100644 --- a/tests/modeldiffs/criteo1tb_embed_init/compare.py +++ b/tests/modeldiffs/criteo1tb_embed_init/compare.py @@ -8,10 +8,12 @@ import torch from algoperf import spec -from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \ - Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import ( + Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload, +) +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import ( + Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -53,31 +55,34 @@ def sd_transform(sd): pytorch_workload = PyTorchWorkload() pytorch_batch = { - 'inputs': torch.ones((2, 13 + 26)), - 'targets': torch.randint(low=0, high=1, size=(2,)), - 'weights': torch.ones(2), + 'inputs': torch.ones((2, 13 + 26)), + 'targets': torch.randint(low=0, high=1, size=(2,)), + 'weights': torch.ones(2), } jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} # Test outputs for identical weights and inputs. 
pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py index 5aad3cc67..db1dec601 100644 --- a/tests/modeldiffs/criteo1tb_layernorm/compare.py +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -8,10 +8,12 @@ import torch from algoperf import spec -from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \ - Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import ( + Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload, +) +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import ( + Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -65,31 +67,34 @@ def sd_transform(sd): pytorch_workload = PyTorchWorkload() pytorch_batch = { - 'inputs': torch.ones((2, 13 + 26)), - 'targets': torch.randint(low=0, high=1, size=(2,)), - 'weights': torch.ones(2), + 'inputs': torch.ones((2, 13 + 26)), + 'targets': torch.randint(low=0, high=1, size=(2,)), + 'weights': torch.ones(2), } jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} # Test outputs for identical weights and inputs. 
pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - # mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + # mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 169b1cdf4..4851a8ad4 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -9,10 +9,12 @@ import torch from algoperf import spec -from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \ - Criteo1TbDlrmSmallResNetWorkload as JaxWorkload -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import ( + Criteo1TbDlrmSmallResNetWorkload as JaxWorkload, +) +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import ( + Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -65,9 +67,9 @@ def sd_transform(sd): pytorch_workload = PyTorchWorkload() pytorch_batch = { - 'inputs': torch.ones((2, 13 + 26)), - 'targets': torch.randint(low=0, high=1, size=(2,)), - 'weights': torch.ones(2), + 'inputs': torch.ones((2, 13 + 26)), + 'targets': torch.randint(low=0, high=1, size=(2,)), + 'weights': torch.ones(2), } init_fake_batch_size = 2 @@ -80,23 +82,26 @@ def sd_transform(sd): # Test outputs for identical weights and inputs. 
pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/diff.py b/tests/modeldiffs/diff.py index 52241fd3a..a74159028 100644 --- a/tests/modeldiffs/diff.py +++ b/tests/modeldiffs/diff.py @@ -1,21 +1,23 @@ -from flax import jax_utils -from flax.core import FrozenDict import jax import numpy as np import torch +from flax import jax_utils +from flax.core import FrozenDict -from tests.modeldiffs.torch2jax_utils import Torch2Jax -from tests.modeldiffs.torch2jax_utils import value_transform +from tests.modeldiffs.torch2jax_utils import Torch2Jax, value_transform -#pylint: disable=dangerous-default-value -def torch2jax(jax_workload, - pytorch_workload, - key_transform=None, - sd_transform=None, - init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0)): - jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), - **init_kwargs) +# pylint: disable=dangerous-default-value +def torch2jax( + jax_workload, + pytorch_workload, + key_transform=None, + sd_transform=None, + init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0), +): + jax_params, model_state = jax_workload.init_model_fn( + jax.random.PRNGKey(0), **init_kwargs + ) pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) if isinstance(jax_params, dict): jax_params = FrozenDict(jax_params) @@ -24,8 +26,9 @@ def torch2jax(jax_workload, model_state = jax_utils.unreplicate(model_state) if isinstance( - pytorch_model, - (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): + pytorch_model, + (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel), + ): pytorch_model = pytorch_model.module # Map and copy params of pytorch_model to jax_model. 
t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) @@ -39,22 +42,24 @@ def torch2jax(jax_workload, return jax_params, model_state, pytorch_model -def out_diff(jax_workload, - pytorch_workload, - jax_model_kwargs, - pytorch_model_kwargs, - key_transform=None, - sd_transform=None, - out_transform=None): - jax_params, model_state, pytorch_model = torch2jax(jax_workload, - pytorch_workload, - key_transform, - sd_transform) - out_p, _ = pytorch_workload.model_fn(params=pytorch_model, - **pytorch_model_kwargs) - out_j, _ = jax_workload.model_fn(params=jax_params, - model_state=model_state, - **jax_model_kwargs) +def out_diff( + jax_workload, + pytorch_workload, + jax_model_kwargs, + pytorch_model_kwargs, + key_transform=None, + sd_transform=None, + out_transform=None, +): + jax_params, model_state, pytorch_model = torch2jax( + jax_workload, pytorch_workload, key_transform, sd_transform + ) + out_p, _ = pytorch_workload.model_fn( + params=pytorch_model, **pytorch_model_kwargs + ) + out_j, _ = jax_workload.model_fn( + params=jax_params, model_state=model_state, **jax_model_kwargs + ) if out_transform is not None: out_p = out_transform(out_p) out_j = out_transform(out_j) @@ -67,15 +72,16 @@ def out_diff(jax_workload, class ModelDiffRunner: - - def __init__(self, - jax_workload, - pytorch_workload, - jax_model_kwargs, - pytorch_model_kwargs, - key_transform=None, - sd_transform=None, - out_transform=None) -> None: + def __init__( + self, + jax_workload, + pytorch_workload, + jax_model_kwargs, + pytorch_model_kwargs, + key_transform=None, + sd_transform=None, + out_transform=None, + ) -> None: """ Initializes the instance based on diffing logic. @@ -83,7 +89,7 @@ def __init__(self, jax_workload: Workload implementation using JAX. pytorch_workload: Workload implementation using PyTorch. jax_model_kwargs: Arguments to be used for model_fn in jax workload. - pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch + pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch workload. key_transform: Transformation function for keys. sd_transform: Transformation function for State Dictionary. 
@@ -99,10 +105,12 @@ def __init__(self, self.out_transform = out_transform def run(self): - out_diff(self.jax_workload, - self.pytorch_workload, - self.jax_model_kwargs, - self.pytorch_model_kwargs, - self.key_transform, - self.sd_transform, - self.out_transform) + out_diff( + self.jax_workload, + self.pytorch_workload, + self.jax_model_kwargs, + self.pytorch_model_kwargs, + self.key_transform, + self.sd_transform, + self.out_transform, + ) diff --git a/tests/modeldiffs/fastmri/compare.py b/tests/modeldiffs/fastmri/compare.py index c1a349cec..6a82bfb58 100644 --- a/tests/modeldiffs/fastmri/compare.py +++ b/tests/modeldiffs/fastmri/compare.py @@ -7,15 +7,16 @@ import torch from algoperf import spec -from algoperf.workloads.fastmri.fastmri_jax.workload import \ - FastMRIWorkload as JaxWorkload -from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRIWorkload as PyTorchWorkload +from algoperf.workloads.fastmri.fastmri_jax.workload import ( + FastMRIWorkload as JaxWorkload, +) +from algoperf.workloads.fastmri.fastmri_pytorch.workload import ( + FastMRIWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner def sd_transform(sd): - def sort_key(k): if k[0] == 'ModuleList_0': return (0, *k) @@ -34,7 +35,7 @@ def sort_key(k): if 'ModuleList' in i or 'Sequential' in i: continue if i.startswith('ConvBlock'): - if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]: + if idx != 0 and keys[idx - 1][: idx2 + 1] != k[: idx2 + 1]: c += 1 i = f'ConvBlock_{c}' if 'Conv2d' in i: @@ -64,22 +65,25 @@ def sort_key(k): pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=None, - sd_transform=sd_transform).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=None, + sd_transform=sd_transform, + ).run() diff --git a/tests/modeldiffs/fastmri_layernorm/compare.py b/tests/modeldiffs/fastmri_layernorm/compare.py index f26ad185e..8ad47bcae 100644 --- a/tests/modeldiffs/fastmri_layernorm/compare.py +++ b/tests/modeldiffs/fastmri_layernorm/compare.py @@ -7,15 +7,16 @@ import torch from algoperf import spec -from algoperf.workloads.fastmri.fastmri_jax.workload import \ - FastMRILayerNormWorkload as JaxWorkload -from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRILayerNormWorkload as PyTorchWorkload +from algoperf.workloads.fastmri.fastmri_jax.workload import ( + FastMRILayerNormWorkload as JaxWorkload, +) +from algoperf.workloads.fastmri.fastmri_pytorch.workload import ( + FastMRILayerNormWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner def sd_transform(sd): - def sort_key(k): if k[0] == 'ModuleList_0': return (0, 
*k) @@ -35,7 +36,7 @@ def sort_key(k): if 'ModuleList' in i or 'Sequential' in i: continue if i.startswith('ConvBlock'): - if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]: + if idx != 0 and keys[idx - 1][: idx2 + 1] != k[: idx2 + 1]: c += 1 i = f'ConvBlock_{c}' if 'Conv2d' in i: @@ -71,22 +72,25 @@ def sort_key(k): pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=None, - sd_transform=sd_transform).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=None, + sd_transform=sd_transform, + ).run() diff --git a/tests/modeldiffs/fastmri_model_size/compare.py b/tests/modeldiffs/fastmri_model_size/compare.py index 42789539b..f6d5c5074 100644 --- a/tests/modeldiffs/fastmri_model_size/compare.py +++ b/tests/modeldiffs/fastmri_model_size/compare.py @@ -7,15 +7,16 @@ import torch from algoperf import spec -from algoperf.workloads.fastmri.fastmri_jax.workload import \ - FastMRIModelSizeWorkload as JaxWorkload -from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRIModelSizeWorkload as PyTorchWorkload +from algoperf.workloads.fastmri.fastmri_jax.workload import ( + FastMRIModelSizeWorkload as JaxWorkload, +) +from algoperf.workloads.fastmri.fastmri_pytorch.workload import ( + FastMRIModelSizeWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner def sd_transform(sd): - def sort_key(k): if k[0] == 'ModuleList_0': return (0, *k) @@ -34,7 +35,7 @@ def sort_key(k): if 'ModuleList' in i or 'Sequential' in i: continue if i.startswith('ConvBlock'): - if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]: + if idx != 0 and keys[idx - 1][: idx2 + 1] != k[: idx2 + 1]: c += 1 i = f'ConvBlock_{c}' if 'Conv2d' in i: @@ -64,22 +65,25 @@ def sort_key(k): pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=None, - sd_transform=sd_transform).run() + 
jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=None, + sd_transform=sd_transform, + ).run() diff --git a/tests/modeldiffs/fastmri_tanh/compare.py b/tests/modeldiffs/fastmri_tanh/compare.py index 13ecb890c..714a025b3 100644 --- a/tests/modeldiffs/fastmri_tanh/compare.py +++ b/tests/modeldiffs/fastmri_tanh/compare.py @@ -7,15 +7,16 @@ import torch from algoperf import spec -from algoperf.workloads.fastmri.fastmri_jax.workload import \ - FastMRITanhWorkload as JaxWorkload -from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRITanhWorkload as PyTorchWorkload +from algoperf.workloads.fastmri.fastmri_jax.workload import ( + FastMRITanhWorkload as JaxWorkload, +) +from algoperf.workloads.fastmri.fastmri_pytorch.workload import ( + FastMRITanhWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner def sd_transform(sd): - def sort_key(k): if k[0] == 'ModuleList_0': return (0, *k) @@ -34,7 +35,7 @@ def sort_key(k): if 'ModuleList' in i or 'Sequential' in i: continue if i.startswith('ConvBlock'): - if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]: + if idx != 0 and keys[idx - 1][: idx2 + 1] != k[: idx2 + 1]: c += 1 i = f'ConvBlock_{c}' if 'Conv2d' in i: @@ -64,22 +65,25 @@ def sort_key(k): pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=None, - sd_transform=sd_transform).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=None, + sd_transform=sd_transform, + ).run() diff --git a/tests/modeldiffs/imagenet_resnet/compare.py b/tests/modeldiffs/imagenet_resnet/compare.py index 59ab45555..e43cd069e 100644 --- a/tests/modeldiffs/imagenet_resnet/compare.py +++ b/tests/modeldiffs/imagenet_resnet/compare.py @@ -7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ - ImagenetResNetWorkload as JaxWorkload -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ - ImagenetResNetWorkload as PyTorchWorkload +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ( + ImagenetResNetWorkload as JaxWorkload, +) +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ( + ImagenetResNetWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -53,11 +55,13 @@ def sd_transform(sd): c += 1 new_key = (f'BottleneckResNetBlock_{c}',) + k[2:] if 'Sequential' in ''.join(new_key): - new_key = tuple([ + new_key = tuple( + [ (i.replace('_0', '_proj') if 'BatchNorm' in i or 'Conv' in i else i) for i in 
new_key if 'Sequential' not in i - ]) + ] + ) sd[new_key] = sd[k] del sd[k] elif 'BatchNorm' in k[0] or 'Conv' in k[0]: @@ -81,22 +85,25 @@ def sd_transform(sd): pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + ).run() diff --git a/tests/modeldiffs/imagenet_resnet/gelu_compare.py b/tests/modeldiffs/imagenet_resnet/gelu_compare.py index 07510ad70..a92712ddc 100644 --- a/tests/modeldiffs/imagenet_resnet/gelu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/gelu_compare.py @@ -7,13 +7,14 @@ import torch from algoperf import spec -from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ - ImagenetResNetGELUWorkload as JaxWorkload -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ - ImagenetResNetGELUWorkload as PyTorchWorkload +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ( + ImagenetResNetGELUWorkload as JaxWorkload, +) +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ( + ImagenetResNetGELUWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner -from tests.modeldiffs.imagenet_resnet.compare import key_transform -from tests.modeldiffs.imagenet_resnet.compare import sd_transform +from tests.modeldiffs.imagenet_resnet.compare import key_transform, sd_transform if __name__ == '__main__': # pylint: disable=locally-disabled, not-callable @@ -28,22 +29,25 @@ pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + 
sd_transform=sd_transform, + ).run() diff --git a/tests/modeldiffs/imagenet_resnet/silu_compare.py b/tests/modeldiffs/imagenet_resnet/silu_compare.py index 8246d17a2..bbbfd082b 100644 --- a/tests/modeldiffs/imagenet_resnet/silu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/silu_compare.py @@ -7,13 +7,14 @@ import torch from algoperf import spec -from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ - ImagenetResNetSiLUWorkload as JaxWorkload -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ - ImagenetResNetSiLUWorkload as PyTorchWorkload +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ( + ImagenetResNetSiLUWorkload as JaxWorkload, +) +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ( + ImagenetResNetSiLUWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner -from tests.modeldiffs.imagenet_resnet.compare import key_transform -from tests.modeldiffs.imagenet_resnet.compare import sd_transform +from tests.modeldiffs.imagenet_resnet.compare import key_transform, sd_transform if __name__ == '__main__': # pylint: disable=locally-disabled, not-callable @@ -28,22 +29,25 @@ pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + ).run() diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index b4ca7d8ec..84282d4be 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \ - ImagenetVitWorkload as JaxWorkload -from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetVitWorkload as PyTorchWorkload +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ( + ImagenetVitWorkload as JaxWorkload, +) +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ( + ImagenetVitWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -40,16 +42,16 @@ def key_transform(k): if attention: if pool_head: i = { - 'Linear_0': 'query', - 'Linear_1': 'key_value', - 'Linear_2': 'out', + 'Linear_0': 'query', + 'Linear_1': 'key_value', + 'Linear_2': 'out', }[i] else: i = { - 'Linear_0': 'query', - 'Linear_1': 'key', - 'Linear_2': 'value', - 'Linear_3': 'out', + 'Linear_0': 'query', + 'Linear_1': 'key', + 'Linear_2': 'value', + 'Linear_3': 'out', }[i] else: i = i.replace('Linear', 
'Dense') @@ -94,22 +96,25 @@ def key_transform(k): pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ).run() diff --git a/tests/modeldiffs/imagenet_vit_glu/compare.py b/tests/modeldiffs/imagenet_vit_glu/compare.py index c152410b5..55c010b97 100644 --- a/tests/modeldiffs/imagenet_vit_glu/compare.py +++ b/tests/modeldiffs/imagenet_vit_glu/compare.py @@ -10,10 +10,12 @@ import torch from algoperf import spec -from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \ - ImagenetVitGluWorkload as JaxWorkload -from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetVitGluWorkload as PyTorchWorkload +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ( + ImagenetVitGluWorkload as JaxWorkload, +) +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ( + ImagenetVitGluWorkload as PyTorchWorkload, +) sd_transform = None @@ -30,22 +32,25 @@ pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ).run() diff --git a/tests/modeldiffs/imagenet_vit_map/compare.py b/tests/modeldiffs/imagenet_vit_map/compare.py index 7f1af41ab..17a7483c2 100644 --- a/tests/modeldiffs/imagenet_vit_map/compare.py +++ b/tests/modeldiffs/imagenet_vit_map/compare.py @@ -10,10 +10,12 @@ import torch from algoperf import spec -from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \ - ImagenetVitMapWorkload as JaxWorkload -from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \ 
- ImagenetVitMapWorkload as PytWorkload +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ( + ImagenetVitMapWorkload as JaxWorkload, +) +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ( + ImagenetVitMapWorkload as PytWorkload, +) def sd_transform(sd): @@ -41,22 +43,25 @@ def sd_transform(sd): pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + ).run() diff --git a/tests/modeldiffs/imagenet_vit_postln/compare.py b/tests/modeldiffs/imagenet_vit_postln/compare.py index a3a639101..72d407a4b 100644 --- a/tests/modeldiffs/imagenet_vit_postln/compare.py +++ b/tests/modeldiffs/imagenet_vit_postln/compare.py @@ -10,10 +10,12 @@ import torch from algoperf import spec -from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \ - ImagenetVitPostLNWorkload as JaxWorkload -from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetViTPostLNWorkload as PyTorchWorkload +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ( + ImagenetVitPostLNWorkload as JaxWorkload, +) +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ( + ImagenetViTPostLNWorkload as PyTorchWorkload, +) sd_transform = None @@ -30,22 +32,25 @@ pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ).run() diff --git a/tests/modeldiffs/librispeech_conformer/compare.py b/tests/modeldiffs/librispeech_conformer/compare.py index 664b1242d..80aba62f2 100644 --- 
a/tests/modeldiffs/librispeech_conformer/compare.py +++ b/tests/modeldiffs/librispeech_conformer/compare.py @@ -7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ - LibriSpeechConformerWorkload as JaxWorkload -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerWorkload as PyTorchWorkload +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import ( + LibriSpeechConformerWorkload as JaxWorkload, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import ( + LibriSpeechConformerWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -69,24 +71,27 @@ def sd_transform(sd): pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] + * (1 - out_outpad[1][:, :, None]), + ).run() diff --git a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py index b0812e77d..a7bebf6bf 100644 --- a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py +++ b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py @@ -7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ - LibriSpeechConformerAttentionTemperatureWorkload as JaxWorkload -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerAttentionTemperatureWorkload as PyTorchWorkload +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import ( + LibriSpeechConformerAttentionTemperatureWorkload as JaxWorkload, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import ( + LibriSpeechConformerAttentionTemperatureWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -69,24 +71,27 @@ def sd_transform(sd): pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + 
update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] + * (1 - out_outpad[1][:, :, None]), + ).run() diff --git a/tests/modeldiffs/librispeech_conformer_gelu/compare.py b/tests/modeldiffs/librispeech_conformer_gelu/compare.py index 3032a0005..d8f7980a2 100644 --- a/tests/modeldiffs/librispeech_conformer_gelu/compare.py +++ b/tests/modeldiffs/librispeech_conformer_gelu/compare.py @@ -7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ - LibriSpeechConformerGeluWorkload as JaxWorkload -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerGeluWorkload as PyTorchWorkload +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import ( + LibriSpeechConformerGeluWorkload as JaxWorkload, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import ( + LibriSpeechConformerGeluWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -69,24 +71,27 @@ def sd_transform(sd): pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] + * (1 - out_outpad[1][:, :, None]), + ).run() diff --git a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py index d623ef352..7f4768c11 100644 --- a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py +++ b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py @@ 
-7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ - LibriSpeechConformerLayerNormWorkload as JaxWorkload -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerLayerNormWorkload as PyTorchWorkload +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import ( + LibriSpeechConformerLayerNormWorkload as JaxWorkload, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import ( + LibriSpeechConformerLayerNormWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -69,24 +71,27 @@ def sd_transform(sd): pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] + * (1 - out_outpad[1][:, :, None]), + ).run() diff --git a/tests/modeldiffs/librispeech_deepspeech/compare.py b/tests/modeldiffs/librispeech_deepspeech/compare.py index 84b0a6c86..bd1073524 100644 --- a/tests/modeldiffs/librispeech_deepspeech/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech/compare.py @@ -7,10 +7,12 @@ import torch from algoperf import spec -from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \ - LibriSpeechDeepSpeechWorkload as JaxWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ - LibriSpeechDeepSpeechWorkload as PyTorchWorkload +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import ( + LibriSpeechDeepSpeechWorkload as JaxWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import ( + LibriSpeechDeepSpeechWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner @@ -56,9 +58,9 @@ def sd_transform(sd): else: out[k] = sd[k] elif 'LSTM' in ''.join(k): - l = out.get(k[:-1], dict()) - l[k[-1]] = sd[k] - out[k[:-1]] = l + l_tmp = out.get(k[:-1], dict()) + l_tmp[k[-1]] = sd[k] + out[k[:-1]] = l_tmp else: out[k] = sd[k] keys_to_del = [] @@ -67,10 +69,12 @@ def sd_transform(sd): if isinstance(out[k], dict): kernels = ['kernel_ih_l0', 'kernel_hh_l0'] biases = ['bias_ih_l0', 'bias_hh_l0'] - weights = torch.cat([out[k][i].view(-1) for i in kernels] + - [out[k][i + '_reverse'].view(-1) for i in kernels] + - [out[k][i].view(-1) for i in biases] + - [out[k][i + '_reverse'].view(-1) for i 
in biases]) + weights = torch.cat( + [out[k][i].view(-1) for i in kernels] + + [out[k][i + '_reverse'].view(-1) for i in kernels] + + [out[k][i].view(-1) for i in biases] + + [out[k][i + '_reverse'].view(-1) for i in biases] + ) updates[k + ('weights',)] = weights keys_to_del.append(k) out.update(updates) @@ -94,24 +98,27 @@ def sd_transform(sd): pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] + * (1 - out_outpad[1][:, :, None]), + ).run() diff --git a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py index 2540c1b93..8593894e4 100644 --- a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py @@ -7,13 +7,17 @@ import torch from algoperf import spec -from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \ - LibriSpeechDeepSpeechTanhWorkload as JaxWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ - LibriSpeechDeepSpeechTanhWorkload as PyTorchWorkload +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import ( + LibriSpeechDeepSpeechTanhWorkload as JaxWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import ( + LibriSpeechDeepSpeechTanhWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner -from tests.modeldiffs.librispeech_deepspeech.compare import key_transform -from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform +from tests.modeldiffs.librispeech_deepspeech.compare import ( + key_transform, + sd_transform, +) if __name__ == '__main__': # pylint: disable=locally-disabled, not-callable @@ -30,24 +34,27 @@ pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + 
mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] + * (1 - out_outpad[1][:, :, None]), + ).run() diff --git a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py index e5972120d..27e4760a6 100644 --- a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py @@ -7,13 +7,17 @@ import torch from algoperf import spec -from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \ - LibriSpeechDeepSpeechNormAndSpecAugWorkload as JaxWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ - LibriSpeechDeepSpeechNormAndSpecAugWorkload as PyTorchWorkload +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import ( + LibriSpeechDeepSpeechNormAndSpecAugWorkload as JaxWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import ( + LibriSpeechDeepSpeechNormAndSpecAugWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner -from tests.modeldiffs.librispeech_deepspeech.compare import key_transform -from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform +from tests.modeldiffs.librispeech_deepspeech.compare import ( + key_transform, + sd_transform, +) if __name__ == '__main__': # pylint: disable=locally-disabled, not-callable @@ -30,24 +34,27 @@ pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] + * (1 - out_outpad[1][:, :, None]), + ).run() diff --git a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py index 4d2c4a5d5..7990f063b 100644 --- 
a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py @@ -7,13 +7,17 @@ import torch from algoperf import spec -from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \ - LibriSpeechDeepSpeechNoResNetWorkload as JaxWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ - LibriSpeechDeepSpeechNoResNetWorkload as PyTorchWorkload +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import ( + LibriSpeechDeepSpeechNoResNetWorkload as JaxWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import ( + LibriSpeechDeepSpeechNoResNetWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner -from tests.modeldiffs.librispeech_deepspeech.compare import key_transform -from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform +from tests.modeldiffs.librispeech_deepspeech.compare import ( + key_transform, + sd_transform, +) if __name__ == '__main__': # pylint: disable=locally-disabled, not-callable @@ -30,24 +34,27 @@ pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] + * (1 - out_outpad[1][:, :, None]), + ).run() diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 5d5ef50bf..8c40d3c8a 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -9,10 +9,12 @@ import torch from algoperf import spec -from algoperf.workloads.ogbg.ogbg_jax.workload import \ - OgbgWorkload as JaxWorkload -from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ - OgbgWorkload as PyTorchWorkload +from algoperf.workloads.ogbg.ogbg_jax.workload import ( + OgbgWorkload as JaxWorkload, +) +from algoperf.workloads.ogbg.ogbg_pytorch.workload import ( + OgbgWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner # Todo: refactor tests to use workload properties in cleaner way @@ -41,8 +43,8 @@ def key_transform(k): layer_index = int(i.split('_')[1]) if graph_network: count = ( - graph_index * 3 * hidden_dims + seq_index * hidden_dims + - layer_index) + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + ) i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -54,7 +56,8 @@ def key_transform(k): elif 
'LayerNorm' in i: layer_index = int(i.split('_')[1]) count = ( - graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + ) i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: @@ -80,13 +83,14 @@ def sd_transform(sd): pytorch_workload = PyTorchWorkload() pytorch_batch = dict( - n_node=torch.LongTensor([5]), - n_edge=torch.LongTensor([5]), - nodes=torch.randn(5, 9), - edges=torch.randn(5, 3), - globals=torch.randn(1, 128), - senders=torch.LongTensor(list(range(5))), - receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) + n_node=torch.LongTensor([5]), + n_edge=torch.LongTensor([5]), + nodes=torch.randn(5, 9), + edges=torch.randn(5, 3), + globals=torch.randn(1, 128), + senders=torch.LongTensor(list(range(5))), + receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]), + ) jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} @@ -98,23 +102,26 @@ def sd_transform(sd): pytorch_batch = {'inputs': graph_p} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index fc3992998..f35ed8b17 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -9,10 +9,12 @@ import torch from algoperf import spec -from algoperf.workloads.ogbg.ogbg_jax.workload import \ - OgbgGeluWorkload as JaxWorkload -from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ - OgbgGeluWorkload as PyTorchWorkload +from algoperf.workloads.ogbg.ogbg_jax.workload import ( + OgbgGeluWorkload as JaxWorkload, +) +from algoperf.workloads.ogbg.ogbg_pytorch.workload import ( + OgbgGeluWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner # Todo: refactor tests to use workload properties in cleaner way @@ -41,8 +43,8 @@ def key_transform(k): layer_index = int(i.split('_')[1]) if graph_network: count = ( - graph_index * 3 * hidden_dims + seq_index * hidden_dims + - layer_index) + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + ) i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -54,7 +56,8 @@ def key_transform(k): elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) count = ( - graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + ) i = 'LayerNorm_' + str(count) 
elif 'weight' in i: if bn or ln: @@ -80,13 +83,14 @@ def sd_transform(sd): pytorch_workload = PyTorchWorkload() pytorch_batch = dict( - n_node=torch.LongTensor([5]), - n_edge=torch.LongTensor([5]), - nodes=torch.randn(5, 9), - edges=torch.randn(5, 3), - globals=torch.randn(1, 128), - senders=torch.LongTensor(list(range(5))), - receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) + n_node=torch.LongTensor([5]), + n_edge=torch.LongTensor([5]), + nodes=torch.randn(5, 9), + edges=torch.randn(5, 3), + globals=torch.randn(1, 128), + senders=torch.LongTensor(list(range(5))), + receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]), + ) jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} @@ -98,23 +102,26 @@ def sd_transform(sd): pytorch_batch = {'inputs': graph_p} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index e7cfa745c..0042c71af 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -9,10 +9,12 @@ import torch from algoperf import spec -from algoperf.workloads.ogbg.ogbg_jax.workload import \ - OgbgModelSizeWorkload as JaxWorkload -from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ - OgbgModelSizeWorkload as PyTorchWorkload +from algoperf.workloads.ogbg.ogbg_jax.workload import ( + OgbgModelSizeWorkload as JaxWorkload, +) +from algoperf.workloads.ogbg.ogbg_pytorch.workload import ( + OgbgModelSizeWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner # Todo: refactor tests to use workload properties in cleaner way @@ -41,8 +43,8 @@ def key_transform(k): layer_index = int(i.split('_')[1]) if graph_network: count = ( - graph_index * 3 * hidden_dims + seq_index * hidden_dims + - layer_index) + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + ) i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -54,7 +56,8 @@ def key_transform(k): elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) count = ( - graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + ) i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: @@ -80,13 +83,14 @@ def sd_transform(sd): pytorch_workload = PyTorchWorkload() pytorch_batch = dict( - n_node=torch.LongTensor([5]), - n_edge=torch.LongTensor([5]), 
- nodes=torch.randn(5, 9), - edges=torch.randn(5, 3), - globals=torch.randn(1, 128), - senders=torch.LongTensor(list(range(5))), - receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) + n_node=torch.LongTensor([5]), + n_edge=torch.LongTensor([5]), + nodes=torch.randn(5, 9), + edges=torch.randn(5, 3), + globals=torch.randn(1, 128), + senders=torch.LongTensor(list(range(5))), + receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]), + ) jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} @@ -98,23 +102,26 @@ def sd_transform(sd): pytorch_batch = {'inputs': graph_p} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 4e3b96cf7..7583282cd 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -9,10 +9,12 @@ import torch from algoperf import spec -from algoperf.workloads.ogbg.ogbg_jax.workload import \ - OgbgSiluWorkload as JaxWorkload -from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ - OgbgSiluWorkload as PyTorchWorkload +from algoperf.workloads.ogbg.ogbg_jax.workload import ( + OgbgSiluWorkload as JaxWorkload, +) +from algoperf.workloads.ogbg.ogbg_pytorch.workload import ( + OgbgSiluWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner # Todo: refactor tests to use workload properties in cleaner way @@ -41,8 +43,8 @@ def key_transform(k): layer_index = int(i.split('_')[1]) if graph_network: count = ( - graph_index * 3 * hidden_dims + seq_index * hidden_dims + - layer_index) + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + ) i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -54,7 +56,8 @@ def key_transform(k): elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) count = ( - graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + ) i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: @@ -80,13 +83,14 @@ def sd_transform(sd): pytorch_workload = PyTorchWorkload() pytorch_batch = dict( - n_node=torch.LongTensor([5]), - n_edge=torch.LongTensor([5]), - nodes=torch.randn(5, 9), - edges=torch.randn(5, 3), - globals=torch.randn(1, 128), - senders=torch.LongTensor(list(range(5))), - receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) + n_node=torch.LongTensor([5]), + 
n_edge=torch.LongTensor([5]), + nodes=torch.randn(5, 9), + edges=torch.randn(5, 3), + globals=torch.randn(1, 128), + senders=torch.LongTensor(list(range(5))), + receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]), + ) jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} @@ -98,23 +102,26 @@ def sd_transform(sd): pytorch_batch = {'inputs': graph_p} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index d9264b400..4c95ca7e4 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -1,5 +1,5 @@ -from collections import Counter import pprint +from collections import Counter def jax_like_pytorch_statedict(model, state_dict, keys=None): @@ -32,13 +32,17 @@ def flatten(jm, ret, keys=None): def value_transform(k, value, jax_value): k_str = ''.join(k).lower() - if ('conv' in k_str and 'kernel' in k_str) or \ - ('embedding' in k_str and 'kernel' in k_str): + if ('conv' in k_str and 'kernel' in k_str) or ( + 'embedding' in k_str and 'kernel' in k_str + ): if 'transpose' in k_str: # Assumes 2D ConvTranspose with stride equal to kernel_size. 
- return value.reshape(value.shape[0], value.shape[1], - -1).flip(-1).permute(2, 0, - 1).reshape(*jax_value.shape) + return ( + value.reshape(value.shape[0], value.shape[1], -1) + .flip(-1) + .permute(2, 0, 1) + .reshape(*jax_value.shape) + ) else: rank = len(value.shape) if rank == 3: @@ -51,16 +55,17 @@ def value_transform(k, value, jax_value): value = value.t().reshape(*list(jax_value.shape)) elif 'attention' in k_str and 'bias' in k_str: value = value.reshape(*list(jax_value.shape)) - elif ('dense' in k_str and 'kernel' in k_str) or \ - ('lstm' in k_str and 'kernel' in k_str) or \ - ('head' in k_str and 'kernel' in k_str) or \ - ('pre_logits' in k_str and 'kernel' in k_str): + elif ( + ('dense' in k_str and 'kernel' in k_str) + or ('lstm' in k_str and 'kernel' in k_str) + or ('head' in k_str and 'kernel' in k_str) + or ('pre_logits' in k_str and 'kernel' in k_str) + ): value = value.t() return value class Torch2Jax: - def __init__(self, torch_model, jax_model): self.torch_model = torch_model self.jax_model = jax_model @@ -73,13 +78,13 @@ def __init__(self, torch_model, jax_model): def key_transform(self, k_transform_fn): self.pytorch_sd = { - k_transform_fn(k): self.pytorch_sd[k] for k in self.pytorch_sd + k_transform_fn(k): self.pytorch_sd[k] for k in self.pytorch_sd } def value_transform(self, v_transform_fn): self.pytorch_sd = { - k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) - for k in self.pytorch_sd + k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) + for k in self.pytorch_sd } def sd_transform(self, sd_transform_fn): diff --git a/tests/modeldiffs/vanilla_sgd_jax.py b/tests/modeldiffs/vanilla_sgd_jax.py index 5595894e6..aa7bebd4f 100644 --- a/tests/modeldiffs/vanilla_sgd_jax.py +++ b/tests/modeldiffs/vanilla_sgd_jax.py @@ -1,28 +1,33 @@ -from flax import jax_utils import jax import jax.numpy as jnp import optax +from flax import jax_utils from algoperf import spec -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import \ - update_params # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 + data_selection, +) +from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401 + update_params, +) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Vanilla SGD Optimizer.""" del model_params del model_state del rng # Create optimizer. 
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes + ) opt_init_fn, opt_update_fn = optax.sgd(learning_rate=0.001) optimizer_state = opt_init_fn(params_zeros_like) diff --git a/tests/modeldiffs/vanilla_sgd_pytorch.py b/tests/modeldiffs/vanilla_sgd_pytorch.py index a6a0c5fa6..6448ac097 100644 --- a/tests/modeldiffs/vanilla_sgd_pytorch.py +++ b/tests/modeldiffs/vanilla_sgd_pytorch.py @@ -1,24 +1,29 @@ import torch from algoperf import spec -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ - update_params # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401 + data_selection, +) +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401 + update_params, +) -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: """Creates a Vanilla SGD Optimizer.""" del model_state del rng optimizer_state = { - 'optimizer': - torch.optim.SGD(model_params.parameters(), lr=0.001, weight_decay=0), + 'optimizer': torch.optim.SGD( + model_params.parameters(), lr=0.001, weight_decay=0 + ), } return optimizer_state diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 109bfa629..02175c8b5 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -8,17 +8,20 @@ from algoperf import spec from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWorkload -from algoperf.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkload as PyTorchWorkload +from algoperf.workloads.wmt.wmt_pytorch.workload import ( + WmtWorkload as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): new_key = [] for i in k: - if 'ModuleList' in i or\ - 'TransformerDecoder_' in i or\ - 'TransformerEncoder_' in i: + if ( + 'ModuleList' in i + or 'TransformerDecoder_' in i + or 'TransformerEncoder_' in i + ): continue if 'Linear' in i: if 'NonDynamicallyQuantizableLinear' in i: @@ -60,7 +63,7 @@ def sd_transform(sd): pass else: if new_key[-2] == 'Dense_0': - #q + # q out[(*new_key[:-2], 'query', new_key[-1])] = sd[k] pass elif new_key[-2] == 'Dense_1': @@ -73,11 +76,11 @@ def sd_transform(sd): out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass out = { - tuple( - k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + tuple( + k.replace('SelfAttention', 'MultiHeadDotProductAttention') + for k in key + ): value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) @@ -112,29 +115,32 @@ def sd_transform(sd): tgt_tokens = torch.randint(low=0, high=32000, size=(2, 256)) jax_batch = { - 'inputs': inp_tokens.detach().numpy(), - 'targets': tgt_tokens.detach().numpy(), + 'inputs': inp_tokens.detach().numpy(), + 'targets': tgt_tokens.detach().numpy(), } 
pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index 1aa20fe3b..0c834dc86 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -7,19 +7,23 @@ import torch from algoperf import spec -from algoperf.workloads.wmt.wmt_jax.workload import \ - WmtWorkloadAttentionTemp as JaxWorkload -from algoperf.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkloadAttentionTemp as PyTorchWorkload +from algoperf.workloads.wmt.wmt_jax.workload import ( + WmtWorkloadAttentionTemp as JaxWorkload, +) +from algoperf.workloads.wmt.wmt_pytorch.workload import ( + WmtWorkloadAttentionTemp as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): new_key = [] for i in k: - if 'ModuleList' in i or\ - 'TransformerDecoder_' in i or\ - 'TransformerEncoder_' in i: + if ( + 'ModuleList' in i + or 'TransformerDecoder_' in i + or 'TransformerEncoder_' in i + ): continue if 'Linear' in i: if 'NonDynamicallyQuantizableLinear' in i: @@ -61,7 +65,7 @@ def sd_transform(sd): pass else: if new_key[-2] == 'Dense_0': - #q + # q out[(*new_key[:-2], 'query', new_key[-1])] = sd[k] pass elif new_key[-2] == 'Dense_1': @@ -74,11 +78,11 @@ def sd_transform(sd): out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass out = { - tuple( - k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + tuple( + k.replace('SelfAttention', 'MultiHeadDotProductAttention') + for k in key + ): value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) @@ -113,29 +117,32 @@ def sd_transform(sd): tgt_tokens = torch.randint(low=0, high=32000, size=(2, 256)) jax_batch = { - 'inputs': inp_tokens.detach().numpy(), - 'targets': tgt_tokens.detach().numpy(), + 'inputs': inp_tokens.detach().numpy(), + 'targets': tgt_tokens.detach().numpy(), } pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) 
jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index e98a6945d..f7de12326 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -7,19 +7,23 @@ import torch from algoperf import spec -from algoperf.workloads.wmt.wmt_jax.workload import \ - WmtWorkloadGLUTanH as JaxWorkload -from algoperf.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkloadGLUTanH as PyTorchWorkload +from algoperf.workloads.wmt.wmt_jax.workload import ( + WmtWorkloadGLUTanH as JaxWorkload, +) +from algoperf.workloads.wmt.wmt_pytorch.workload import ( + WmtWorkloadGLUTanH as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): new_key = [] for i in k: - if 'ModuleList' in i or\ - 'TransformerDecoder_' in i or\ - 'TransformerEncoder_' in i: + if ( + 'ModuleList' in i + or 'TransformerDecoder_' in i + or 'TransformerEncoder_' in i + ): continue if 'Linear' in i: if 'NonDynamicallyQuantizableLinear' in i: @@ -61,7 +65,7 @@ def sd_transform(sd): pass else: if new_key[-2] == 'Dense_0': - #q + # q out[(*new_key[:-2], 'query', new_key[-1])] = sd[k] pass elif new_key[-2] == 'Dense_1': @@ -74,11 +78,11 @@ def sd_transform(sd): out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass out = { - tuple( - k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + tuple( + k.replace('SelfAttention', 'MultiHeadDotProductAttention') + for k in key + ): value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) @@ -113,29 +117,32 @@ def sd_transform(sd): tgt_tokens = torch.randint(low=0, high=32000, size=(2, 256)) jax_batch = { - 'inputs': inp_tokens.detach().numpy(), - 'targets': tgt_tokens.detach().numpy(), + 'inputs': inp_tokens.detach().numpy(), + 'targets': tgt_tokens.detach().numpy(), } pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - 
pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index d110715b5..a8681ca8e 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -7,19 +7,23 @@ import torch from algoperf import spec -from algoperf.workloads.wmt.wmt_jax.workload import \ - WmtWorkloadPostLN as JaxWorkload -from algoperf.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkloadPostLN as PyTorchWorkload +from algoperf.workloads.wmt.wmt_jax.workload import ( + WmtWorkloadPostLN as JaxWorkload, +) +from algoperf.workloads.wmt.wmt_pytorch.workload import ( + WmtWorkloadPostLN as PyTorchWorkload, +) from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): new_key = [] for i in k: - if 'ModuleList' in i or\ - 'TransformerDecoder_' in i or\ - 'TransformerEncoder_' in i: + if ( + 'ModuleList' in i + or 'TransformerDecoder_' in i + or 'TransformerEncoder_' in i + ): continue if 'Linear' in i: if 'NonDynamicallyQuantizableLinear' in i: @@ -61,7 +65,7 @@ def sd_transform(sd): pass else: if new_key[-2] == 'Dense_0': - #q + # q out[(*new_key[:-2], 'query', new_key[-1])] = sd[k] pass elif new_key[-2] == 'Dense_1': @@ -74,11 +78,11 @@ def sd_transform(sd): out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass out = { - tuple( - k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + tuple( + k.replace('SelfAttention', 'MultiHeadDotProductAttention') + for k in key + ): value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) @@ -113,29 +117,32 @@ def sd_transform(sd): tgt_tokens = torch.randint(low=0, high=32000, size=(2, 256)) jax_batch = { - 'inputs': inp_tokens.detach().numpy(), - 'targets': tgt_tokens.detach().numpy(), + 'inputs': inp_tokens.detach().numpy(), + 'targets': tgt_tokens.detach().numpy(), } pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pytorch_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) + augmented_and_preprocessed_input_batch=pytorch_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False, + ) jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False, + ) ModelDiffRunner( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=sd_transform, - out_transform=None).run() + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None, + ).run() diff --git a/tests/reference_algorithm_tests.py 
b/tests/reference_algorithm_tests.py index c4ca514a8..d17848aaf 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -27,67 +27,66 @@ import os import pickle -from absl import flags -from absl import logging -from absl.testing import absltest import flax -from flax import jax_utils -from flax.core.frozen_dict import FrozenDict import jax -from jraph import GraphsTuple import numpy as np import tensorflow as tf import torch import torch.distributed as dist +from absl import flags, logging +from absl.testing import absltest +from flax import jax_utils +from flax.core.frozen_dict import FrozenDict +from jraph import GraphsTuple -from algoperf import halton -from algoperf import pytorch_utils +import submission_runner +from algoperf import halton, pytorch_utils from algoperf import random_utils as prng from algoperf.profiler import PassThroughProfiler from algoperf.workloads import workloads from algoperf.workloads.ogbg import input_pipeline as ogbg_input_pipeline from algoperf.workloads.ogbg.ogbg_pytorch.workload import _graph_map -import submission_runner from tests.modeldiffs import diff as diff_utils flags.DEFINE_integer( - 'global_batch_size', - -1, - ('Global Batch size to use when running an individual workload. Otherwise ' - 'a per-device batch size of 2 is used.')) + 'global_batch_size', + -1, + ( + 'Global Batch size to use when running an individual workload. Otherwise ' + 'a per-device batch size of 2 is used.' + ), +) flags.DEFINE_integer('num_train_steps', 1, 'Number of steps to train.') flags.DEFINE_boolean('use_fake_input_queue', True, 'Use fake data examples.') flags.DEFINE_string('log_file', '/tmp/log.pkl', 'The log file') flags.DEFINE_boolean( - 'all', - False, - 'Run all workloads instead of using --workload and --framework.') -flags.DEFINE_boolean('identical', - False, - 'Run jax and pytorch with identical weights.') + 'all', False, 'Run all workloads instead of using --workload and --framework.' +) +flags.DEFINE_boolean( + 'identical', False, 'Run jax and pytorch with identical weights.' 
+) FLAGS = flags.FLAGS USE_PYTORCH_DDP, RANK, PYTORCH_DEVICE, N_GPUS = pytorch_utils.pytorch_setup() tf.config.set_visible_devices([], 'GPU') _EXPECTED_METRIC_NAMES = { - 'cifar': ['train/loss', 'validation/loss', 'test/accuracy'], - 'criteo1tb': ['train/loss', 'validation/loss'], - 'criteo1tb_test': ['train/loss', 'validation/loss'], - 'fastmri': ['train/ssim', 'validation/ssim'], - 'imagenet_resnet': ['train/accuracy', 'validation/accuracy'], - 'imagenet_vit': ['train/accuracy', 'validation/accuracy'], - 'librispeech_conformer': ['train/wer', 'validation/wer', 'train/ctc_loss'], - 'librispeech_deepspeech': ['train/wer', 'validation/wer', 'train/ctc_loss'], - 'mnist': ['train/loss', 'validation/accuracy', 'test/accuracy'], - 'ogbg': [ - 'train/accuracy', 'validation/loss', 'test/mean_average_precision' - ], - 'wmt': ['train/bleu', 'validation/loss', 'validation/accuracy'], + 'cifar': ['train/loss', 'validation/loss', 'test/accuracy'], + 'criteo1tb': ['train/loss', 'validation/loss'], + 'criteo1tb_test': ['train/loss', 'validation/loss'], + 'fastmri': ['train/ssim', 'validation/ssim'], + 'imagenet_resnet': ['train/accuracy', 'validation/accuracy'], + 'imagenet_vit': ['train/accuracy', 'validation/accuracy'], + 'librispeech_conformer': ['train/wer', 'validation/wer', 'train/ctc_loss'], + 'librispeech_deepspeech': ['train/wer', 'validation/wer', 'train/ctc_loss'], + 'mnist': ['train/loss', 'validation/accuracy', 'test/accuracy'], + 'ogbg': ['train/accuracy', 'validation/loss', 'test/mean_average_precision'], + 'wmt': ['train/bleu', 'validation/loss', 'validation/accuracy'], } def _make_fake_image_batch(batch_shape, data_shape, num_classes): - examples = np.random.normal(size=(*batch_shape, - *data_shape)).astype(np.float32) + examples = np.random.normal(size=(*batch_shape, *data_shape)).astype( + np.float32 + ) labels = np.random.randint(0, num_classes, size=batch_shape) masks = np.ones(batch_shape, dtype=np.float32) return {'inputs': examples, 'targets': labels, 'weights': masks} @@ -96,16 +95,17 @@ def _make_fake_image_batch(batch_shape, data_shape, num_classes): def _pytorch_map(inputs): if USE_PYTORCH_DDP: return jax.tree.map( - lambda a: torch.as_tensor(a[RANK], device=PYTORCH_DEVICE), inputs) + lambda a: torch.as_tensor(a[RANK], device=PYTORCH_DEVICE), inputs + ) return jax.tree.map( - lambda a: torch.as_tensor(a, device=PYTORCH_DEVICE).view(-1, a.shape[-1]) - if len(a.shape) == 3 else torch.as_tensor(a, device=PYTORCH_DEVICE).view( - -1), - inputs) + lambda a: torch.as_tensor(a, device=PYTORCH_DEVICE).view(-1, a.shape[-1]) + if len(a.shape) == 3 + else torch.as_tensor(a, device=PYTORCH_DEVICE).view(-1), + inputs, + ) class _FakeTokenizer: - def detokenize(self, *args): del args return tf.constant('this is a fake sequence?') @@ -113,15 +113,14 @@ def detokenize(self, *args): @flax.struct.dataclass class _FakeMetricsCollection: - def merge(self, *args): del args return self def compute(self): return { - 'wer': 0.0, - 'ctc_loss': 0.0, + 'wer': 0.0, + 'ctc_loss': 0.0, } def unreplicate(self): @@ -129,7 +128,6 @@ def unreplicate(self): class _FakeMetricsLogger: - def __init__(self): self.filename = FLAGS.log_file self.scalars = [] @@ -152,27 +150,27 @@ def append_eval_metrics(self, result): def save(self): with open(self.filename, 'wb') as f: - pickle.dump({'scalars': self.scalars, 'eval_results': self.eval_results}, - f) + pickle.dump( + {'scalars': self.scalars, 'eval_results': self.eval_results}, f + ) class _FakeMetricsBundle: - def gather_from_model_output(self, *args, **kwargs): del 
args del kwargs return _FakeMetricsCollection() -def _make_one_batch_workload(workload_class, - workload_name, - framework, - global_batch_size, - use_fake_input_queue, - n_gpus): - +def _make_one_batch_workload( + workload_class, + workload_name, + framework, + global_batch_size, + use_fake_input_queue, + n_gpus, +): class _OneEvalBatchWorkload(workload_class): - def __init__(self): kwargs = {} if 'librispeech' in workload_name: @@ -184,24 +182,30 @@ def __init__(self): if 'librispeech' in workload_name: self.tokenizer = _FakeTokenizer() - def init_model_fn(self, rng, dropout_rate=None, aux_dropout_rate=None): + def init_model_fn(self, rng): # pylint: disable=line-too-long - if not (FLAGS.identical and - os.path.exists(f'tests/modeldiffs/{workload_name}/compare.py')): - return super().init_model_fn( - rng, dropout_rate=dropout_rate, aux_dropout_rate=aux_dropout_rate) + if not ( + FLAGS.identical + and os.path.exists(f'tests/modeldiffs/{workload_name}/compare.py') + ): + return super().init_model_fn(rng) if framework == 'jax': compare_module = importlib.import_module( - f'tests.modeldiffs.{workload_name}.compare') + f'tests.modeldiffs.{workload_name}.compare' + ) jax_params, model_state, _ = diff_utils.torch2jax( jax_workload=super(), pytorch_workload=compare_module.PyTorchWorkload(**self.init_kwargs), key_transform=compare_module.key_transform, - sd_transform=compare_module.sd_transform) - return (FrozenDict(**jax_utils.replicate(jax_params)), - FrozenDict(**jax_utils.replicate(model_state)) - if model_state is not None else model_state) - return super().init_model_fn([0], dropout_rate=0.0, aux_dropout_rate=0.0) + sd_transform=compare_module.sd_transform, + ) + return ( + FrozenDict(**jax_utils.replicate(jax_params)), + FrozenDict(**jax_utils.replicate(model_state)) + if model_state is not None + else model_state, + ) + return super().init_model_fn([0]) @property def num_eval_train_examples(self): @@ -236,73 +240,74 @@ def _build_input_queue(self, *args, **kwargs): else: data_shape = (3, 32, 32) fake_batch = _make_fake_image_batch( - batch_shape, data_shape=data_shape, num_classes=10) + batch_shape, data_shape=data_shape, num_classes=10 + ) elif workload_name == 'criteo1tb' or workload_name == 'criteo1tb_test': targets = np.ones(batch_shape) targets[0] = 0 fake_batch = { - 'inputs': np.ones((*batch_shape, 13 + 26)), - 'targets': targets, - 'weights': np.ones(batch_shape), + 'inputs': np.ones((*batch_shape, 13 + 26)), + 'targets': targets, + 'weights': np.ones(batch_shape), } elif workload_name in ['imagenet_resnet', 'imagenet_vit']: data_shape = (224, 224, 3) fake_batch = _make_fake_image_batch( - batch_shape, data_shape=data_shape, num_classes=1000) + batch_shape, data_shape=data_shape, num_classes=1000 + ) if framework == 'pytorch': num_dims = len(fake_batch['inputs'].shape) fake_batch['inputs'] = fake_batch['inputs'].transpose( - (*range(num_dims - 3), num_dims - 1, num_dims - 3, num_dims - 2)) + (*range(num_dims - 3), num_dims - 1, num_dims - 3, num_dims - 2) + ) elif 'librispeech' in workload_name: rate = 16000 - l = None - while l is None or l.shape[-1] < 320000: + audio_signal = None + while audio_signal is None or audio_signal.shape[-1] < 320000: duration = 0.5 - freq = 2**(np.random.rand(*batch_shape, 1) * 13) + freq = 2 ** (np.random.rand(*batch_shape, 1) * 13) wav = np.sin(2 * np.pi * freq * np.arange(rate * duration) / rate) - if l is None: - l = wav + if audio_signal is None: + audio_signal = wav else: - l = np.concatenate([l, wav], axis=-1) - inputs = l + audio_signal = 
np.concatenate([audio_signal, wav], axis=-1) + inputs = audio_signal targets = np.random.randint(low=1, high=1024, size=(*batch_shape, 256)) tgt_pad = np.arange(0, 256)[tuple([None] * len(batch_shape))] tgt_lengths = np.random.randint( - low=100, high=256, size=(*batch_shape, 1)) + low=100, high=256, size=(*batch_shape, 1) + ) tgt_pad = 1 * (tgt_pad > tgt_lengths) fake_batch = { - 'inputs': (inputs, np.zeros_like(inputs)), - 'targets': (targets, tgt_pad), + 'inputs': (inputs, np.zeros_like(inputs)), + 'targets': (targets, tgt_pad), } elif workload_name == 'mnist': fake_batch = _make_fake_image_batch( - batch_shape, data_shape=(28, 28, 1), num_classes=10) + batch_shape, data_shape=(28, 28, 1), num_classes=10 + ) elif workload_name == 'ogbg': tf.random.set_seed(5) def _fake_iter(): while True: fake_batch = { - 'num_nodes': - tf.ones((1,), dtype=tf.int64), - 'edge_index': - tf.ones((1, 2), dtype=tf.int64), - 'node_feat': - tf.random.normal((1, 9)), - 'edge_feat': - tf.random.normal((1, 3)), - 'labels': - tf.cast( - tf.random.uniform((self._num_outputs,), - minval=0, - maxval=2, - dtype=tf.int32), - tf.float32), + 'num_nodes': tf.ones((1,), dtype=tf.int64), + 'edge_index': tf.ones((1, 2), dtype=tf.int64), + 'node_feat': tf.random.normal((1, 9)), + 'edge_feat': tf.random.normal((1, 3)), + 'labels': tf.cast( + tf.random.uniform( + (self._num_outputs,), minval=0, maxval=2, dtype=tf.int32 + ), + tf.float32, + ), } yield fake_batch fake_batch_iter = ogbg_input_pipeline._get_batch_iterator( - _fake_iter(), global_batch_size) + _fake_iter(), global_batch_size + ) fake_batch = next(fake_batch_iter) # pylint: disable=stop-iteration-return if framework == 'pytorch': fake_batch['inputs'] = _graph_map(_pytorch_map, fake_batch['inputs']) @@ -311,48 +316,49 @@ def _fake_iter(): elif workload_name == 'wmt': max_len = 256 fake_batch = { - 'inputs': - np.random.randint( - low=0, high=32000, size=(*batch_shape, max_len)), - 'targets': - np.random.randint( - low=0, high=32000, size=(*batch_shape, max_len)), - 'weights': - np.random.randint(low=0, high=2, size=(*batch_shape, max_len)), + 'inputs': np.random.randint( + low=0, high=32000, size=(*batch_shape, max_len) + ), + 'targets': np.random.randint( + low=0, high=32000, size=(*batch_shape, max_len) + ), + 'weights': np.random.randint( + low=0, high=2, size=(*batch_shape, max_len) + ), } self._tokenizer = _FakeTokenizer() elif workload_name == 'fastmri': data_shape = (320, 320) fake_batch = { - 'inputs': - _make_fake_image_batch( - batch_shape, data_shape=data_shape, num_classes=1000) - ['inputs'], - 'targets': - _make_fake_image_batch( - batch_shape, data_shape=data_shape, num_classes=1000) - ['inputs'], - 'mean': - np.zeros(batch_shape), - 'std': - np.ones(batch_shape), - 'volume_max': - np.zeros(batch_shape), - 'weights': - np.ones(batch_shape), + 'inputs': _make_fake_image_batch( + batch_shape, data_shape=data_shape, num_classes=1000 + )['inputs'], + 'targets': _make_fake_image_batch( + batch_shape, data_shape=data_shape, num_classes=1000 + )['inputs'], + 'mean': np.zeros(batch_shape), + 'std': np.ones(batch_shape), + 'volume_max': np.zeros(batch_shape), + 'weights': np.ones(batch_shape), } else: raise ValueError( - 'Workload {} does not have a fake batch defined, you ' - 'can add it or use --use_fake_input_queue=false.'.format( - workload_name)) + 'Workload {} does not have a fake batch defined, you ' + 'can add it or use --use_fake_input_queue=false.'.format( + workload_name + ) + ) if framework == 'pytorch': def to_device(k, v): dtype = ( - torch.long if 
(k == 'targets' and workload_name != 'fastmri') else - torch.bool if k == 'weights' else torch.float) + torch.long + if (k == 'targets' and workload_name != 'fastmri') + else torch.bool + if k == 'weights' + else torch.float + ) if USE_PYTORCH_DDP: v = v[RANK] return torch.as_tensor(v, device=PYTORCH_DEVICE, dtype=dtype) @@ -388,24 +394,28 @@ def eval_model(self, *args, **kwargs): return _OneEvalBatchWorkload() -def _test_submission(workload_name, - framework, - submission_path, - search_space_path, - data_dir, - use_fake_input_queue, - n_gpus): +def _test_submission( + workload_name, + framework, + submission_path, + search_space_path, + data_dir, + use_fake_input_queue, + n_gpus, +): logging.info(f'========= Testing {workload_name} in {framework}.') FLAGS.framework = framework workload_metadata = copy.deepcopy(submission_runner.WORKLOADS[workload_name]) workload_metadata['workload_path'] = os.path.join( - submission_runner.BASE_WORKLOADS_DIR, - workload_metadata['workload_path'] + '_' + framework, - 'workload.py') + submission_runner.BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + '_' + framework, + 'workload.py', + ) workload_class = workloads.import_workload( - workload_path=workload_metadata['workload_path'], - workload_class_name=workload_metadata['workload_class_name'], - return_class=True) + workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + return_class=True, + ) print(f'Workload class for {workload_name} is {workload_class}') submission_module_path = workloads.convert_filepath_to_module(submission_path) @@ -422,30 +432,32 @@ def _test_submission(workload_name, global_batch_size = FLAGS.global_batch_size if FLAGS.global_batch_size < 0: raise ValueError('Must set --global_batch_size.') - workload = _make_one_batch_workload(workload_class, - workload_name, - framework, - global_batch_size, - use_fake_input_queue, - n_gpus) + workload = _make_one_batch_workload( + workload_class, + workload_name, + framework, + global_batch_size, + use_fake_input_queue, + n_gpus, + ) # Get a sample hyperparameter setting. 
hyperparameters = {} if search_space_path != 'None': with open(search_space_path, 'r', encoding='UTF-8') as search_space_file: hyperparameters = halton.generate_search( - json.load(search_space_file), num_trials=1)[0] + json.load(search_space_file), num_trials=1 + )[0] rng = prng.PRNGKey(0) data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) input_queue = workload._build_input_queue( - data_rng, 'train', data_dir=data_dir, global_batch_size=global_batch_size) + data_rng, 'train', data_dir=data_dir, global_batch_size=global_batch_size + ) model_params, model_state = workload.init_model_fn(model_init_rng) - optimizer_state = init_optimizer_state(workload, - model_params, - model_state, - hyperparameters, - opt_init_rng) + optimizer_state = init_optimizer_state( + workload, model_params, model_state, hyperparameters, opt_init_rng + ) if USE_PYTORCH_DDP: torch.cuda.empty_cache() @@ -453,44 +465,49 @@ def _test_submission(workload_name, for global_step in range(FLAGS.num_train_steps): step_rng = prng.fold_in(rng, global_step) data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3) - batch = data_selection(workload, - input_queue, - optimizer_state, - model_params, - model_state, - hyperparameters, - global_step, - data_select_rng) + batch = data_selection( + workload, + input_queue, + optimizer_state, + model_params, + model_state, + hyperparameters, + global_step, + data_select_rng, + ) optimizer_state, model_params, model_state = update_params( - workload=workload, - current_param_container=model_params, - current_params_types=workload.model_params_types, - model_state=model_state, - hyperparameters=hyperparameters, - batch=batch, - loss_type=workload.loss_type, - optimizer_state=optimizer_state, - train_state={}, - eval_results=[], - global_step=global_step, - rng=update_rng) + workload=workload, + current_param_container=model_params, + current_params_types=workload.model_params_types, + model_state=model_state, + hyperparameters=hyperparameters, + batch=batch, + loss_type=workload.loss_type, + optimizer_state=optimizer_state, + train_state={}, + eval_results=[], + global_step=global_step, + rng=update_rng, + ) eval_result = workload.eval_model( - global_batch_size, - model_params, - model_state, - eval_rng, - data_dir, - imagenet_v2_data_dir=None, - global_step=global_step) - _ = workload.eval_model( global_batch_size, model_params, model_state, eval_rng, data_dir, imagenet_v2_data_dir=None, - global_step=global_step) + global_step=global_step, + ) + _ = workload.eval_model( + global_batch_size, + model_params, + model_state, + eval_rng, + data_dir, + imagenet_v2_data_dir=None, + global_step=global_step, + ) return eval_result @@ -500,12 +517,15 @@ def _make_paths(repo_location, framework, workload_name): else: dataset_name = workload_name workload_dir = ( - f'{repo_location}/reference_algorithms/target_setting_algorithms/' - f'{workload_name}') + f'{repo_location}/reference_algorithms/target_setting_algorithms/' + f'{workload_name}' + ) search_space_path = f'{workload_dir}/tuning_search_space.json' - submission_path = (f'reference_algorithms/target_setting_algorithms/' - f'{workload_name}/{dataset_name}_{framework}/' - 'submission.py') + submission_path = ( + f'reference_algorithms/target_setting_algorithms/' + f'{workload_name}/{dataset_name}_{framework}/' + 'submission.py' + ) full_submission_path = f'{repo_location}/{submission_path}' if not os.path.exists(full_submission_path): return None, None @@ -535,7 +555,8 @@ def test_submission(self): if 
FLAGS.tuning_search_space: raise ValueError('Cannot set --tuning_search_space and --all.') references_dir = ( - f'{repo_location}/reference_algorithms/target_setting_algorithms') + f'{repo_location}/reference_algorithms/target_setting_algorithms' + ) for workload_name in os.listdir(references_dir): for framework in ['jax', 'pytorch']: if framework == 'pytorch': @@ -543,17 +564,19 @@ def test_submission(self): # First jax operation has to be called after pytorch_init. n_gpus = max(N_GPUS, jax.local_device_count()) search_space_path, submission_path = _make_paths( - repo_location, framework, workload_name) + repo_location, framework, workload_name + ) if search_space_path is None: continue eval_result = _test_submission( - workload_name, - framework, - submission_path, - search_space_path, - data_dir=FLAGS.data_dir, - use_fake_input_queue=FLAGS.use_fake_input_queue, - n_gpus=n_gpus) + workload_name, + framework, + submission_path, + search_space_path, + data_dir=FLAGS.data_dir, + use_fake_input_queue=FLAGS.use_fake_input_queue, + n_gpus=n_gpus, + ) self._assert_eval_result(workload_name, eval_result) else: framework = FLAGS.framework @@ -567,15 +590,17 @@ def test_submission(self): submission_path = FLAGS.submission_path else: search_space_path, submission_path = _make_paths( - repo_location, framework, workload_name) + repo_location, framework, workload_name + ) eval_result = _test_submission( - workload_name, - framework, - submission_path, - search_space_path, - data_dir=FLAGS.data_dir, - use_fake_input_queue=FLAGS.use_fake_input_queue, - n_gpus=n_gpus) + workload_name, + framework, + submission_path, + search_space_path, + data_dir=FLAGS.data_dir, + use_fake_input_queue=FLAGS.use_fake_input_queue, + n_gpus=n_gpus, + ) self._assert_eval_result(workload_name, eval_result) if USE_PYTORCH_DDP: diff --git a/tests/submission_runner_test.py b/tests/submission_runner_test.py index ff724b201..c6c993b7b 100644 --- a/tests/submission_runner_test.py +++ b/tests/submission_runner_test.py @@ -4,17 +4,16 @@ dataset to be available. For testing the workload and reference submission code for all workloads, see reference_algorithm_tests.py. 
""" + import copy import os import sys -from absl import flags -from absl import logging -from absl.testing import absltest -from absl.testing import parameterized +from absl import flags, logging +from absl.testing import absltest, parameterized -from algoperf.profiler import PassThroughProfiler import submission_runner +from algoperf.profiler import PassThroughProfiler FLAGS = flags.FLAGS # Needed to avoid UnparsedFlagAccessError @@ -28,47 +27,46 @@ class SubmissionRunnerTest(parameterized.TestCase): """Tests for reference submissions.""" @parameterized.named_parameters( - dict( - testcase_name='mnist_jax', - workload='mnist', - framework='jax', - submission_path=(f'{_MNIST_DEV_ALGO_DIR}/mnist_jax/submission.py'), - tuning_search_space=( - f'{_MNIST_DEV_ALGO_DIR}/tuning_search_space.json')), - dict( - testcase_name='mnist_pytorch', - workload='mnist', - framework='pytorch', - submission_path=( - f'{_MNIST_DEV_ALGO_DIR}/mnist_pytorch/submission.py'), - tuning_search_space=( - f'{_MNIST_DEV_ALGO_DIR}/tuning_search_space.json')), + dict( + testcase_name='mnist_jax', + workload='mnist', + framework='jax', + submission_path=(f'{_MNIST_DEV_ALGO_DIR}/mnist_jax/submission.py'), + tuning_search_space=(f'{_MNIST_DEV_ALGO_DIR}/tuning_search_space.json'), + ), + dict( + testcase_name='mnist_pytorch', + workload='mnist', + framework='pytorch', + submission_path=(f'{_MNIST_DEV_ALGO_DIR}/mnist_pytorch/submission.py'), + tuning_search_space=(f'{_MNIST_DEV_ALGO_DIR}/tuning_search_space.json'), + ), ) - def test_submission(self, - workload, - framework, - submission_path, - tuning_search_space): + def test_submission( + self, workload, framework, submission_path, tuning_search_space + ): FLAGS.framework = framework workload_metadata = copy.deepcopy(submission_runner.WORKLOADS[workload]) workload_metadata['workload_path'] = os.path.join( - submission_runner.BASE_WORKLOADS_DIR, - workload_metadata['workload_path'] + '_' + framework, - 'workload.py') + submission_runner.BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + '_' + framework, + 'workload.py', + ) workload_obj = submission_runner.import_workload( - workload_path=workload_metadata['workload_path'], - workload_class_name=workload_metadata['workload_class_name'], - workload_init_kwargs={}) + workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + workload_init_kwargs={}, + ) score = submission_runner.score_submission_on_workload( - workload_obj, - workload, - submission_path, - data_dir='~/tensorflow_datasets', # The default in TFDS. - tuning_ruleset='external', - tuning_search_space=tuning_search_space, - num_tuning_trials=1, - profiler=PassThroughProfiler(), - max_global_steps=500, + workload_obj, + workload, + submission_path, + data_dir='~/tensorflow_datasets', # The default in TFDS. + tuning_ruleset='external', + tuning_search_space=tuning_search_space, + num_tuning_trials=1, + profiler=PassThroughProfiler(), + max_global_steps=500, ) logging.info(score) diff --git a/tests/test_baselines.py b/tests/test_baselines.py index b2be8aa11..c5097a567 100644 --- a/tests/test_baselines.py +++ b/tests/test_baselines.py @@ -1,20 +1,19 @@ """Tests for submission.py for baselines. -This is an end-to-end test for all baselines on MNIST in PyTorch and Jax that +This is an end-to-end test for all baselines on MNIST in PyTorch and Jax that requires the dataset to be available. 
""" + import copy import os import sys -from absl import flags -from absl import logging -from absl.testing import absltest -from absl.testing import parameterized +from absl import flags, logging +from absl.testing import absltest, parameterized +import submission_runner from algoperf.profiler import PassThroughProfiler from algoperf.workloads import workloads -import submission_runner FLAGS = flags.FLAGS # Needed to avoid UnparsedFlagAccessError @@ -24,41 +23,42 @@ MAX_GLOBAL_STEPS = 5 baselines = { - 'jax': [ - 'adafactor', - 'adamw', - 'lamb', - 'momentum', - 'nadamw', - 'nesterov', - 'sam', - 'shampoo', - ], - 'pytorch': [ - 'adamw', - 'momentum', - 'nadamw', - 'nesterov', - ], + 'jax': [ + 'adafactor', + 'adamw', + 'lamb', + 'momentum', + 'nadamw', + 'nesterov', + 'sam', + 'shampoo', + ], + 'pytorch': [ + 'adamw', + 'momentum', + 'nadamw', + 'nesterov', + ], } frameworks = [ - 'pytorch', - 'jax', + 'pytorch', + 'jax', ] -baseline_path = "reference_algorithms/paper_baselines" +baseline_path = 'reference_algorithms/paper_baselines' named_parameters = [] for f in frameworks: for b in baselines[f]: named_parameters.append( - dict( - testcase_name=f'{b}_{f}', - workload='mnist', - framework=f'{f}', - submission_path=f'{baseline_path}/{b}/{f}/submission.py', - tuning_search_space=f'{baseline_path}/{b}/tuning_search_space.json') + dict( + testcase_name=f'{b}_{f}', + workload='mnist', + framework=f'{f}', + submission_path=f'{baseline_path}/{b}/{f}/submission.py', + tuning_search_space=f'{baseline_path}/{b}/tuning_search_space.json', + ) ) @@ -66,31 +66,31 @@ class BaselineTest(parameterized.TestCase): """Tests for reference submissions.""" @parameterized.named_parameters(*named_parameters) - def test_baseline_submission(self, - workload, - framework, - submission_path, - tuning_search_space): + def test_baseline_submission( + self, workload, framework, submission_path, tuning_search_space + ): FLAGS.framework = framework workload_metadata = copy.deepcopy(workloads.WORKLOADS[workload]) workload_metadata['workload_path'] = os.path.join( - workloads.BASE_WORKLOADS_DIR, - workload_metadata['workload_path'] + '_' + framework, - 'workload.py') + workloads.BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + '_' + framework, + 'workload.py', + ) workload_obj = workloads.import_workload( - workload_path=workload_metadata['workload_path'], - workload_class_name=workload_metadata['workload_class_name'], - workload_init_kwargs={}) + workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + workload_init_kwargs={}, + ) score = submission_runner.score_submission_on_workload( - workload_obj, - workload, - submission_path, - data_dir='~/tensorflow_datasets', # The default in TFDS. - tuning_ruleset='external', - tuning_search_space=tuning_search_space, - num_tuning_trials=1, - profiler=PassThroughProfiler(), - max_global_steps=MAX_GLOBAL_STEPS, + workload_obj, + workload, + submission_path, + data_dir='~/tensorflow_datasets', # The default in TFDS. 
+ tuning_ruleset='external', + tuning_search_space=tuning_search_space, + num_tuning_trials=1, + profiler=PassThroughProfiler(), + max_global_steps=MAX_GLOBAL_STEPS, ) logging.info(score) diff --git a/tests/test_jax_utils.py b/tests/test_jax_utils.py new file mode 100644 index 000000000..28e506400 --- /dev/null +++ b/tests/test_jax_utils.py @@ -0,0 +1,215 @@ +""" +Test algoperf.jax_utils.Dropout by comparing to flax.linen.Dropout +Run it as: pytest +""" + +from functools import partial + +import flax.linen as nn +import jax +import jax.numpy as jnp +from absl.testing import absltest, parameterized +from jax.tree_util import tree_leaves, tree_map, tree_structure + +from algoperf.jax_utils import Dropout + +SEED = 1996 +DEFAULT_DROPOUT = 0.5 + + +def pytrees_are_equal(a, b, rtol=1e-5, atol=1e-8): + """ + A custom function to check if two PyTrees are equal, handling floats with + a tolerance. + """ + if tree_structure(a) != tree_structure(b): + return False + + def leaf_comparator(x, y): + # Use allclose for floating-point JAX arrays + if isinstance(x, jnp.ndarray) and jnp.issubdtype(x.dtype, jnp.floating): + return jnp.allclose(x, y, rtol=rtol, atol=atol) + # Use standard equality for everything else + else: + return x == y + + comparison_tree = tree_map(leaf_comparator, a, b) + all_equal = all(tree_leaves(comparison_tree)) + + return all_equal + + +class LegacyDropoutModel(nn.Module): + dropout_rate: float = DEFAULT_DROPOUT + + @nn.compact + def __call__(self, x, train): + return nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) + + +class DropoutModel(nn.Module): + @nn.compact + def __call__(self, x, train, dropout_rate=DEFAULT_DROPOUT): + return Dropout(rate=dropout_rate, deterministic=not train)( + x, rate=dropout_rate + ) + + +class DropoutTest(parameterized.TestCase): + @parameterized.named_parameters( + dict(testcase_name='Dropout, p=0.0, train', dropout_rate=0.0, mode='train'), + dict(testcase_name='Dropout, p=0.0, eval', dropout_rate=0.0, mode='eval'), + dict(testcase_name='Dropout, p=0.1, train', dropout_rate=0.1, mode='train'), + dict(testcase_name='Dropout, p=0.1, eval', dropout_rate=0.1, mode='eval'), + ) + def test_forward(self, dropout_rate, mode): + """Compare forward pass of Dropout layer to flax.linen.Dropout in train and + eval mode. 
+ """ + + # initialize models + rng, dropout_rng = jax.random.split(jax.random.key(SEED), 2) + fake_batch = jnp.ones((10,)) + orig_model = LegacyDropoutModel(dropout_rate=dropout_rate) + cust_model = DropoutModel() + + initial_variables_original = orig_model.init( + {'params': rng}, fake_batch, train=False + ) + initial_variables_custom = cust_model.init( + {'params': rng}, fake_batch, train=False + ) + + assert pytrees_are_equal( + initial_variables_original, initial_variables_custom, rtol=1e-6 + ) + + # forward pass + x = jnp.ones((10,)) + + train = mode == 'train' + y1 = orig_model.apply( + initial_variables_original, x, train=train, rngs={'dropout': dropout_rng} + ) + y2 = cust_model.apply( + initial_variables_custom, + x, + train=train, + dropout_rate=dropout_rate, + rngs={'dropout': dropout_rng}, + ) + + assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) + + @parameterized.named_parameters( + dict(testcase_name='Dropout, p=0.0, train', dropout_rate=0.0, mode='train'), + dict(testcase_name='Dropout, p=0.0, eval', dropout_rate=0.0, mode='eval'), + dict(testcase_name='Dropout, p=0.1, train', dropout_rate=0.1, mode='train'), + dict(testcase_name='Dropout, p=0.1, eval', dropout_rate=0.1, mode='eval'), + ) + def test_dropout_update(self, dropout_rate, mode): + """Call forward pass of Dropout layer with two different dropout rates + and check that the output matches to flax.linen.Dropout in train and + eval mode. + """ + # init model + rng, dropout_rng = jax.random.split(jax.random.key(SEED), 2) + fake_batch = jnp.ones((10,)) + orig_model = LegacyDropoutModel(dropout_rate=dropout_rate) + cust_model = DropoutModel() + + initial_variables_original = orig_model.init( + {'params': rng}, fake_batch, train=False + ) + + initial_variables_custom = cust_model.init( + {'params': rng}, fake_batch, train=False + ) + + assert pytrees_are_equal( + initial_variables_original, initial_variables_custom, rtol=1e-6 + ) + + # forward pass + x = jnp.ones((10,)) + + train = mode == 'train' + y1 = orig_model.apply( + initial_variables_original, x, train=train, rngs={'dropout': dropout_rng} + ) + + _ = cust_model.apply( + initial_variables_custom, + x, + train=train, + dropout_rate=0.9, + rngs={'dropout': dropout_rng}, + ) + + y2 = cust_model.apply( + initial_variables_custom, + x, + train=train, + dropout_rate=dropout_rate, + rngs={'dropout': dropout_rng}, + ) + assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) + + @parameterized.named_parameters( + dict(testcase_name='Dropout, p=0.0, train', dropout_rate=0.0, mode='train'), + dict(testcase_name='Dropout, p=0.0, eval', dropout_rate=0.0, mode='eval'), + dict(testcase_name='Dropout, p=0.1, train', dropout_rate=0.1, mode='train'), + dict(testcase_name='Dropout, p=0.1, eval', dropout_rate=0.1, mode='eval'), + ) + def test_jitted_updates(self, dropout_rate, mode): + """Compare jitted updates with dropout.""" + + # initialize models + rng, dropout_rng = jax.random.split(jax.random.key(SEED), 2) + fake_batch = jnp.ones((10,)) + orig_model = LegacyDropoutModel(dropout_rate=dropout_rate) + cust_model = DropoutModel() + + initial_variables_original = orig_model.init( + {'params': rng}, fake_batch, train=False + ) + initial_variables_custom = cust_model.init( + {'params': rng}, fake_batch, train=False + ) + + assert pytrees_are_equal( + initial_variables_original, initial_variables_custom, rtol=1e-6 + ) + + # forward pass + x = jnp.ones((10,)) + + train = mode == 'train' + jitted_original_apply = jax.jit( + partial(orig_model.apply), static_argnames=['train'] + ) + 
jitted_custom_apply = jax.jit( + partial(cust_model.apply), static_argnames=['train'] + ) + + for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: + y1 = jitted_original_apply( + initial_variables_original, + x, + train=train, + rngs={'dropout': dropout_rng}, + ) + + for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: + y2 = jitted_custom_apply( + initial_variables_custom, + x, + train=train, + dropout_rate=d, + rngs={'dropout': dropout_rng}, + ) + assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3) + + +if __name__ == '__main__': + absltest.main() diff --git a/tests/test_num_params.py b/tests/test_num_params.py index b0633025e..9361f4c72 100644 --- a/tests/test_num_params.py +++ b/tests/test_num_params.py @@ -5,48 +5,59 @@ import pytest import torch -from algoperf.workloads.criteo1tb.criteo1tb_jax.models import \ - DlrmSmall as JaxDlrmSmall -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import \ - DlrmSmall as PyTorchDlrmSmall -from algoperf.workloads.imagenet_resnet.imagenet_jax.models import \ - ResNet18 as JaxResNet_c10 -from algoperf.workloads.imagenet_resnet.imagenet_jax.models import \ - ResNet50 as JaxResNet -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ - resnet18 as PyTorchResNet_c10 -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ - resnet50 as PyTorchResNet +from algoperf.workloads.criteo1tb.criteo1tb_jax.models import ( + DlrmSmall as JaxDlrmSmall, +) +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import ( + DlrmSmall as PyTorchDlrmSmall, +) +from algoperf.workloads.imagenet_resnet.imagenet_jax.models import ( + ResNet18 as JaxResNet_c10, +) +from algoperf.workloads.imagenet_resnet.imagenet_jax.models import ( + ResNet50 as JaxResNet, +) +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import ( + resnet18 as PyTorchResNet_c10, +) +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import ( + resnet50 as PyTorchResNet, +) from algoperf.workloads.imagenet_vit.imagenet_jax.models import ViT as JaxViT -from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import \ - ViT as PyTorchViT -from algoperf.workloads.librispeech_conformer.librispeech_jax.models import \ - Conformer as JaxConformer -from algoperf.workloads.librispeech_conformer.librispeech_jax.models import \ - ConformerConfig as JaxConformerConfig -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ - ConformerConfig as PytorchConformerConfig -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ - ConformerEncoderDecoder as PytorchConformer +from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import ( + ViT as PyTorchViT, +) +from algoperf.workloads.librispeech_conformer.librispeech_jax.models import ( + Conformer as JaxConformer, +) +from algoperf.workloads.librispeech_conformer.librispeech_jax.models import ( + ConformerConfig as JaxConformerConfig, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import ( + ConformerConfig as PytorchConformerConfig, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import ( + ConformerEncoderDecoder as PytorchConformer, +) from algoperf.workloads.mnist.mnist_jax.workload import _Model as JaxMLP -from algoperf.workloads.mnist.mnist_pytorch.workload import \ - _Model as PyTorchMLP +from algoperf.workloads.mnist.mnist_pytorch.workload import _Model as PyTorchMLP from algoperf.workloads.ogbg.ogbg_jax.models import GNN as JaxGNN from 
algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as PyTorchGNN from algoperf.workloads.wmt.wmt_jax.models import Transformer as JaxTransformer from algoperf.workloads.wmt.wmt_jax.models import TransformerConfig -from algoperf.workloads.wmt.wmt_pytorch.models import \ - Transformer as PyTorchTransformer +from algoperf.workloads.wmt.wmt_pytorch.models import ( + Transformer as PyTorchTransformer, +) WORKLOADS = [ - 'mnist', - 'cifar', - 'criteo1tb', - 'imagenet_resnet', - 'imagenet_vit', - 'wmt', - 'ogbg', - 'librispeech_conformer', + 'mnist', + 'cifar', + 'criteo1tb', + 'imagenet_resnet', + 'imagenet_vit', + 'wmt', + 'ogbg', + 'librispeech_conformer', ] @@ -56,7 +67,8 @@ def test_matching_num_params(workload): # Count parameters of both models. num_jax_params = sum(x.size for x in jax.tree_util.tree_leaves(jax_model)) num_pytorch_params = sum( - p.numel() for p in pytorch_model.parameters() if p.requires_grad) + p.numel() for p in pytorch_model.parameters() if p.requires_grad + ) assert num_jax_params == num_pytorch_params @@ -72,8 +84,9 @@ def get_models(workload): # Init Jax model. input_shape = (1, 32, 32, 3) model_init = jax.jit(JaxResNet_c10(num_classes=10, dtype=jnp.float32).init) - jax_model = model_init(init_rngs, jnp.ones(input_shape, - jnp.float32))["params"] + jax_model = model_init(init_rngs, jnp.ones(input_shape, jnp.float32))[ + 'params' + ] # Init PyTorch model. pytorch_model = PyTorchResNet_c10(num_classes=10) @@ -85,35 +98,38 @@ def get_models(workload): vocab_size = 32 * 128 * 1024 input_shape = (1, 39) model_init = JaxDlrmSmall( - vocab_size=vocab_size, - num_dense_features=13, - mlp_bottom_dims=mlp_bottom_dims, - mlp_top_dims=mlp_top_dims, - embed_dim=embed_dim).init - jax_model = model_init(init_rngs, jnp.ones(input_shape, jnp.float32), - False)['params'] + vocab_size=vocab_size, + num_dense_features=13, + mlp_bottom_dims=mlp_bottom_dims, + mlp_top_dims=mlp_top_dims, + embed_dim=embed_dim, + ).init + jax_model = model_init( + init_rngs, jnp.ones(input_shape, jnp.float32), False + )['params'] # Init PyTorch model. pytorch_model = PyTorchDlrmSmall( - vocab_size=vocab_size, - num_dense_features=13, - mlp_bottom_dims=mlp_bottom_dims, - mlp_top_dims=mlp_top_dims, - embed_dim=embed_dim) + vocab_size=vocab_size, + num_dense_features=13, + mlp_bottom_dims=mlp_bottom_dims, + mlp_top_dims=mlp_top_dims, + embed_dim=embed_dim, + ) elif workload == 'imagenet_resnet': # Init Jax model. input_shape = (1, 224, 224, 3) - jax_model = JaxResNet( - num_classes=1000, - dtype=jnp.float32).init(init_rngs, jnp.ones(input_shape, - jnp.float32))['params'] + jax_model = JaxResNet(num_classes=1000, dtype=jnp.float32).init( + init_rngs, jnp.ones(input_shape, jnp.float32) + )['params'] # Init PyTorch model. pytorch_model = PyTorchResNet() elif workload == 'imagenet_vit': # Init Jax model. input_shape = (1, 224, 224, 3) jax_model = JaxViT(num_classes=1000).init( - init_rngs, jnp.ones(input_shape, jnp.float32))['params'] + init_rngs, jnp.ones(input_shape, jnp.float32) + )['params'] # Init PyTorch model. 
pytorch_model = PyTorchViT() elif workload == 'librispeech_conformer': @@ -123,8 +139,9 @@ def get_models(workload): # Init Jax model input_shape = [(320000,), (320000,)] fake_input_batch = [jnp.zeros((2, *x), jnp.float32) for x in input_shape] - jax_model = jax_model.init( - init_rngs, train=False, *fake_input_batch)["params"] + jax_model = jax_model.init(init_rngs, train=False, *fake_input_batch)[ + 'params' + ] # Run model once to initialize lazy layers wave = torch.randn(2, 320000) @@ -136,23 +153,26 @@ def get_models(workload): input_shape = (16, 256) target_shape = (16, 256) jax_model = JaxTransformer(TransformerConfig).init( - init_rngs, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32))['params'] + init_rngs, + jnp.ones(input_shape, jnp.float32), + jnp.ones(target_shape, jnp.float32), + )['params'] # Init PyTorch model. pytorch_model = PyTorchTransformer() elif workload == 'ogbg': # Init Jax model. fake_batch = jraph.GraphsTuple( - n_node=jnp.asarray([1]), - n_edge=jnp.asarray([1]), - nodes=jnp.ones((1, 9)), - edges=jnp.ones((1, 3)), - globals=jnp.zeros((1, 128)), - senders=jnp.asarray([0]), - receivers=jnp.asarray([0])) + n_node=jnp.asarray([1]), + n_edge=jnp.asarray([1]), + nodes=jnp.ones((1, 9)), + edges=jnp.ones((1, 3)), + globals=jnp.zeros((1, 128)), + senders=jnp.asarray([0]), + receivers=jnp.asarray([0]), + ) jax_model = JaxGNN(num_outputs=128).init( - init_rngs, fake_batch, train=False)['params'] + init_rngs, fake_batch, train=False + )['params'] # Init PyTorch model. pytorch_model = PyTorchGNN(num_outputs=128) else: diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py index df4c798d8..2243ce52e 100644 --- a/tests/test_param_shapes.py +++ b/tests/test_param_shapes.py @@ -5,42 +5,79 @@ import pytest from flax.core import FrozenDict -# isort: skip_file -# pylint:disable=line-too-long -from algoperf.workloads.cifar.cifar_jax.workload import CifarWorkload as JaxCifarWorkload -from algoperf.workloads.cifar.cifar_pytorch.workload import CifarWorkload as PyTorchCifarWorkload -from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload -from algoperf.workloads.fastmri.fastmri_jax.workload import FastMRIWorkload as JaxFastMRIWorkload -from algoperf.workloads.fastmri.fastmri_pytorch.workload import FastMRIWorkload as PyTorchFastMRIWorkload -from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ImagenetResNetWorkload as JaxImagenetResNetWorkload -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ImagenetResNetWorkload as PyTorchImagenetResNetWorkload -from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ImagenetVitWorkload as JaxImagenetViTWorkload -from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ImagenetVitWorkload as PyTorchImagenetViTWorkload -from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import LibriSpeechDeepSpeechWorkload as 
PytorchLibriSpeechDeepSpeechWorkload -from algoperf.workloads.mnist.mnist_jax.workload import MnistWorkload as JaxMnistWorkload -from algoperf.workloads.mnist.mnist_pytorch.workload import MnistWorkload as PyTorchMnistWorkload -from algoperf.workloads.ogbg.ogbg_jax.workload import OgbgWorkload as JaxOgbgWorkload -from algoperf.workloads.ogbg.ogbg_pytorch.workload import OgbgWorkload as PyTorchOgbgWorkload -from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWmtWorkload -from algoperf.workloads.wmt.wmt_pytorch.workload import WmtWorkload as PyTorchWmtWorkload -# pylint:enable=line-too-long +from algoperf.workloads.cifar.cifar_jax.workload import ( + CifarWorkload as JaxCifarWorkload, +) +from algoperf.workloads.cifar.cifar_pytorch.workload import ( + CifarWorkload as PyTorchCifarWorkload, +) +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import ( + Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload, +) +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import ( + Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload, +) +from algoperf.workloads.fastmri.fastmri_jax.workload import ( + FastMRIWorkload as JaxFastMRIWorkload, +) +from algoperf.workloads.fastmri.fastmri_pytorch.workload import ( + FastMRIWorkload as PyTorchFastMRIWorkload, +) +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ( + ImagenetResNetWorkload as JaxImagenetResNetWorkload, +) +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ( + ImagenetResNetWorkload as PyTorchImagenetResNetWorkload, +) +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ( + ImagenetVitWorkload as JaxImagenetViTWorkload, +) +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ( + ImagenetVitWorkload as PyTorchImagenetViTWorkload, +) +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import ( + LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import ( + LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import ( + LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import ( + LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload, +) +from algoperf.workloads.mnist.mnist_jax.workload import ( + MnistWorkload as JaxMnistWorkload, +) +from algoperf.workloads.mnist.mnist_pytorch.workload import ( + MnistWorkload as PyTorchMnistWorkload, +) +from algoperf.workloads.ogbg.ogbg_jax.workload import ( + OgbgWorkload as JaxOgbgWorkload, +) +from algoperf.workloads.ogbg.ogbg_pytorch.workload import ( + OgbgWorkload as PyTorchOgbgWorkload, +) +from algoperf.workloads.wmt.wmt_jax.workload import ( + WmtWorkload as JaxWmtWorkload, +) +from algoperf.workloads.wmt.wmt_pytorch.workload import ( + WmtWorkload as PyTorchWmtWorkload, +) WORKLOADS = [ - 'cifar', - 'criteo1tb', - 'fastmri', - 'imagenet_resnet', - 'imagenet_vit', - # TODO: make tests work for these. - # 'librispeech_conformer', - # 'librispeech_deepspeech', - 'mnist', - 'ogbg', - 'wmt', + 'cifar', + 'criteo1tb', + 'fastmri', + 'imagenet_resnet', + 'imagenet_vit', + # TODO: make tests work for these. 
+ # 'librispeech_conformer', + # 'librispeech_deepspeech', + 'mnist', + 'ogbg', + 'wmt', ] @@ -56,9 +93,11 @@ def test_param_shapes(workload): if isinstance(jax_workload_param_shapes, dict): jax_workload_param_shapes = FrozenDict(jax_workload_param_shapes) jax_param_shapes = jax.tree_util.tree_leaves( - jax_workload_param_shapes.unfreeze()) + jax_workload_param_shapes.unfreeze() + ) pytorch_param_shapes = jax.tree_util.tree_leaves( - pytorch_workload.param_shapes) + pytorch_workload.param_shapes + ) if workload == 'wmt': # The PyTorch transformer for WMT is implemented with fused linear layers # for the projection of QKV inside of the MultiheadAttention module. @@ -74,8 +113,9 @@ def test_param_shapes(workload): # Check if total number of params deduced from shapes match. num_jax_params = 0 num_pytorch_params = 0 - for jax_shape, pytorch_shape in zip_longest(jax_param_shapes, - pytorch_param_shapes): + for jax_shape, pytorch_shape in zip_longest( + jax_param_shapes, pytorch_param_shapes + ): if jax_shape is not None: num_jax_params += np.prod(jax_shape.shape_tuple) if pytorch_shape is not None: diff --git a/tests/test_param_types.py b/tests/test_param_types.py index d3722ae86..9f14f7dd8 100644 --- a/tests/test_param_types.py +++ b/tests/test_param_types.py @@ -1,44 +1,80 @@ import jax import pytest - from absl import logging -from algoperf import spec -# isort: skip_file -# pylint:disable=line-too-long -from algoperf.workloads.cifar.cifar_jax.workload import CifarWorkload as JaxCifarWorkload -from algoperf.workloads.cifar.cifar_pytorch.workload import CifarWorkload as PyTorchCifarWorkload -from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload -from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload -from algoperf.workloads.fastmri.fastmri_jax.workload import FastMRIWorkload as JaxFastMRIWorkload -from algoperf.workloads.fastmri.fastmri_pytorch.workload import FastMRIWorkload as PyTorchFastMRIWorkload -from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ImagenetResNetWorkload as JaxImagenetResNetWorkload -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ImagenetResNetWorkload as PyTorchImagenetResNetWorkload -from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ImagenetVitWorkload as JaxImagenetViTWorkload -from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ImagenetVitWorkload as PyTorchImagenetViTWorkload -from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload -from algoperf.workloads.mnist.mnist_jax.workload import MnistWorkload as JaxMnistWorkload -from algoperf.workloads.mnist.mnist_pytorch.workload import MnistWorkload as PyTorchMnistWorkload -from algoperf.workloads.ogbg.ogbg_jax.workload import OgbgWorkload as JaxOgbgWorkload -from algoperf.workloads.ogbg.ogbg_pytorch.workload import OgbgWorkload as PyTorchOgbgWorkload -from algoperf.workloads.wmt.wmt_jax.workload 
import WmtWorkload as JaxWmtWorkload -from algoperf.workloads.wmt.wmt_pytorch.workload import WmtWorkload as PyTorchWmtWorkload -# pylint:enable=line-too-long +from algoperf import spec +from algoperf.workloads.cifar.cifar_jax.workload import ( + CifarWorkload as JaxCifarWorkload, +) +from algoperf.workloads.cifar.cifar_pytorch.workload import ( + CifarWorkload as PyTorchCifarWorkload, +) +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import ( + Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload, +) +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import ( + Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload, +) +from algoperf.workloads.fastmri.fastmri_jax.workload import ( + FastMRIWorkload as JaxFastMRIWorkload, +) +from algoperf.workloads.fastmri.fastmri_pytorch.workload import ( + FastMRIWorkload as PyTorchFastMRIWorkload, +) +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ( + ImagenetResNetWorkload as JaxImagenetResNetWorkload, +) +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ( + ImagenetResNetWorkload as PyTorchImagenetResNetWorkload, +) +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ( + ImagenetVitWorkload as JaxImagenetViTWorkload, +) +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ( + ImagenetVitWorkload as PyTorchImagenetViTWorkload, +) +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import ( + LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import ( + LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import ( + LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import ( + LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload, +) +from algoperf.workloads.mnist.mnist_jax.workload import ( + MnistWorkload as JaxMnistWorkload, +) +from algoperf.workloads.mnist.mnist_pytorch.workload import ( + MnistWorkload as PyTorchMnistWorkload, +) +from algoperf.workloads.ogbg.ogbg_jax.workload import ( + OgbgWorkload as JaxOgbgWorkload, +) +from algoperf.workloads.ogbg.ogbg_pytorch.workload import ( + OgbgWorkload as PyTorchOgbgWorkload, +) +from algoperf.workloads.wmt.wmt_jax.workload import ( + WmtWorkload as JaxWmtWorkload, +) +from algoperf.workloads.wmt.wmt_pytorch.workload import ( + WmtWorkload as PyTorchWmtWorkload, +) WORKLOADS = [ - 'cifar', - 'criteo1tb', - 'fastmri', - 'imagenet_resnet', - 'imagenet_vit', - 'librispeech_conformer', - 'librispeech_deepspeech', - 'mnist', - 'ogbg', - 'wmt', + 'cifar', + 'criteo1tb', + 'fastmri', + 'imagenet_resnet', + 'imagenet_vit', + 'librispeech_conformer', + 'librispeech_deepspeech', + 'mnist', + 'ogbg', + 'wmt', ] @@ -66,40 +102,32 @@ def _check_attention_qkv_match(jax_param_types_dict, pytorch_param_types_dict): # Sometimes one framework will implement QKV as a single parameter, so we need # to make sure there are the same number of QKV params as Q, K, V. 
   num_qkv = {
-      'jax':
-          jax_param_types_dict.get(spec.ParameterType.ATTENTION_QKV, 0),
-      'pytorch':
-          pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_QKV, 0),
+    'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_QKV, 0),
+    'pytorch': pytorch_param_types_dict.get(
+      spec.ParameterType.ATTENTION_QKV, 0
+    ),
   }
   num_kv = {
-      'jax':
-          jax_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0),
-      'pytorch':
-          pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0),
+    'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0),
+    'pytorch': pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0),
   }
   num_q = {
-      'jax':
-          jax_param_types_dict.get(spec.ParameterType.ATTENTION_Q, 0),
-      'pytorch':
-          pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_Q, 0),
+    'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_Q, 0),
+    'pytorch': pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_Q, 0),
   }
   num_k = {
-      'jax':
-          jax_param_types_dict.get(spec.ParameterType.ATTENTION_K, 0),
-      'pytorch':
-          pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_K, 0),
+    'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_K, 0),
+    'pytorch': pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_K, 0),
   }
   num_v = {
-      'jax':
-          jax_param_types_dict.get(spec.ParameterType.ATTENTION_V, 0),
-      'pytorch':
-          pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_V, 0),
+    'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_V, 0),
+    'pytorch': pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_V, 0),
   }
   num_bias = {
-      'jax':
-          jax_param_types_dict.get(spec.ParameterType.ATTENTION_BIAS, 0),
-      'pytorch':
-          pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_BIAS, 0),
+    'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_BIAS, 0),
+    'pytorch': pytorch_param_types_dict.get(
+      spec.ParameterType.ATTENTION_BIAS, 0
+    ),
   }
   qkv_match = num_qkv['jax'] == num_qkv['pytorch']
   kv_match = num_kv['jax'] == num_kv['pytorch']
@@ -108,24 +136,33 @@ def _check_attention_qkv_match(jax_param_types_dict, pytorch_param_types_dict):
   v_match = num_v['jax'] == num_v['pytorch']
   bias_match = num_bias['jax'] == num_bias['pytorch']
   qkv_match = (
-      qkv_match and kv_match and q_match and k_match and v_match and bias_match)
+    qkv_match and kv_match and q_match and k_match and v_match and bias_match
+  )
   # We subtract 2 * num_qkv from the number of biases because there are 2
   # missing for each of q, k, v.
-  jax_qkv_match = (
-      num_q['pytorch'] == num_k['pytorch'] == num_v['pytorch'] == num_qkv['jax']
-      and (num_qkv['jax'] != 0 and
-           (num_bias['pytorch'] - 2 * num_qkv['jax']) == num_bias['jax']))
-  pytorch_qkv_match = (
-      num_q['jax'] == num_k['jax'] == num_v['jax'] == num_qkv['pytorch'] and
-      (num_qkv['pytorch'] != 0 and
-       (num_bias['jax'] - 2 * num_qkv['pytorch']) == num_bias['pytorch']))
+  jax_qkv_match = num_q['pytorch'] == num_k['pytorch'] == num_v[
+    'pytorch'
+  ] == num_qkv['jax'] and (
+    num_qkv['jax'] != 0
+    and (num_bias['pytorch'] - 2 * num_qkv['jax']) == num_bias['jax']
+  )
+  pytorch_qkv_match = num_q['jax'] == num_k['jax'] == num_v['jax'] == num_qkv[
+    'pytorch'
+  ] and (
+    num_qkv['pytorch'] != 0
+    and (num_bias['jax'] - 2 * num_qkv['pytorch']) == num_bias['pytorch']
+  )
   pytorch_kv_match = (
-      num_q['jax'] == num_k['jax'] == num_v['jax'] ==
-      num_qkv['pytorch'] + num_kv['pytorch'] and
-      num_q['pytorch'] == num_kv['pytorch'])
+    num_q['jax']
+    == num_k['jax']
+    == num_v['jax']
+    == num_qkv['pytorch'] + num_kv['pytorch']
+    and num_q['pytorch'] == num_kv['pytorch']
+  )
   qkv_match = (
-      qkv_match or jax_qkv_match or pytorch_qkv_match or pytorch_kv_match)
+    qkv_match or jax_qkv_match or pytorch_qkv_match or pytorch_kv_match
+  )
   return qkv_match
@@ -137,7 +174,8 @@ def test_param_types(workload_name):
   # Compare number of parameter tensors of both models.
   jax_param_types = jax.tree_util.tree_leaves(jax_workload.model_params_types)
   pytorch_param_types = jax.tree_util.tree_leaves(
-      pytorch_workload.model_params_types)
+    pytorch_workload.model_params_types
+  )
   jax_param_types_dict = count_param_types(jax_param_types)
   pytorch_param_types_dict = count_param_types(pytorch_param_types)
@@ -161,30 +199,33 @@ def test_param_types(workload_name):
   # Check if total number of each type match.
   attention_keys = {
-      spec.ParameterType.ATTENTION_QKV,
-      spec.ParameterType.ATTENTION_KV,
-      spec.ParameterType.ATTENTION_Q,
-      spec.ParameterType.ATTENTION_K,
-      spec.ParameterType.ATTENTION_V,
-      spec.ParameterType.ATTENTION_BIAS,
+    spec.ParameterType.ATTENTION_QKV,
+    spec.ParameterType.ATTENTION_KV,
+    spec.ParameterType.ATTENTION_Q,
+    spec.ParameterType.ATTENTION_K,
+    spec.ParameterType.ATTENTION_V,
+    spec.ParameterType.ATTENTION_BIAS,
   }
   non_attention_keys = set(jax_param_types_dict.keys()).union(
-      set(pytorch_param_types_dict.keys()))
+    set(pytorch_param_types_dict.keys())
+  )
   non_attention_keys -= attention_keys
   mismatches = ''
-  mismatches += _count_mismatches(jax_param_types_dict,
-                                  pytorch_param_types_dict,
-                                  non_attention_keys)
-  qkv_match = _check_attention_qkv_match(jax_param_types_dict,
-                                         pytorch_param_types_dict)
+  mismatches += _count_mismatches(
+    jax_param_types_dict, pytorch_param_types_dict, non_attention_keys
+  )
+  qkv_match = _check_attention_qkv_match(
+    jax_param_types_dict, pytorch_param_types_dict
+  )
   if not qkv_match:
-    mismatches += _count_mismatches(jax_param_types_dict,
-                                    pytorch_param_types_dict,
-                                    attention_keys)
+    mismatches += _count_mismatches(
+      jax_param_types_dict, pytorch_param_types_dict, attention_keys
+    )
   if mismatches:
     raise ValueError(
-        f'On workload {workload_name}, count mismatch: {mismatches}')
+      f'On workload {workload_name}, count mismatch: {mismatches}'
+    )
 
 
 def get_workload(workload_name):
diff --git a/tests/test_ssim.py b/tests/test_ssim.py
index 920556964..dcb3f25e0 100644
--- a/tests/test_ssim.py
+++ b/tests/test_ssim.py
@@ -3,20 +3,20 @@
 import os
 from typing import Tuple
 
-from absl.testing import absltest
-from absl.testing import parameterized
 import jax.numpy as jnp
 import numpy as np
 import torch
+from absl.testing import absltest, parameterized
 
 from algoperf.pytorch_utils import pytorch_setup
-from algoperf.workloads.fastmri.fastmri_jax.ssim import \
-    _uniform_filter as _jax_uniform_filter
+from algoperf.workloads.fastmri.fastmri_jax.ssim import (
+  _uniform_filter as _jax_uniform_filter,
+)
 from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim as jax_ssim
-from algoperf.workloads.fastmri.fastmri_pytorch.ssim import \
-    _uniform_filter as _pytorch_uniform_filter
-from algoperf.workloads.fastmri.fastmri_pytorch.ssim import \
-    ssim as pytorch_ssim
+from algoperf.workloads.fastmri.fastmri_pytorch.ssim import (
+  _uniform_filter as _pytorch_uniform_filter,
+)
+from algoperf.workloads.fastmri.fastmri_pytorch.ssim import ssim as pytorch_ssim
 
 # Make sure no GPU memory is preallocated to Jax.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' @@ -31,7 +31,7 @@ def _create_fake_im(height: int, width: int) -> Tuple[jnp.array, torch.Tensor]: def _create_fake_batch( - batch_size: int, height: int, width: int + batch_size: int, height: int, width: int ) -> Tuple[Tuple[jnp.array, jnp.array], Tuple[torch.Tensor, torch.Tensor]]: logits = np.random.randn(batch_size, height, width) targets = np.random.randn(batch_size, height, width) @@ -47,9 +47,9 @@ class SSIMTest(parameterized.TestCase): and PyTorch.""" @parameterized.named_parameters( - dict(testcase_name='fastmri_im', height=320, width=320), - dict(testcase_name='uneven_even_im', height=31, width=16), - dict(testcase_name='even_uneven_im', height=42, width=53), + dict(testcase_name='fastmri_im', height=320, width=320), + dict(testcase_name='uneven_even_im', height=31, width=16), + dict(testcase_name='even_uneven_im', height=42, width=53), ) def test_uniform_filter(self, height: int, width: int) -> None: jax_im, pytorch_im = _create_fake_im(height, width) @@ -58,12 +58,9 @@ def test_uniform_filter(self, height: int, width: int) -> None: assert np.allclose(jax_result, torch_result, atol=1e-6) @parameterized.named_parameters( - dict( - testcase_name='fastmri_batch', batch_size=256, height=320, width=320), - dict( - testcase_name='uneven_even_batch', batch_size=8, height=31, width=16), - dict( - testcase_name='even_uneven_batch', batch_size=8, height=42, width=53), + dict(testcase_name='fastmri_batch', batch_size=256, height=320, width=320), + dict(testcase_name='uneven_even_batch', batch_size=8, height=31, width=16), + dict(testcase_name='even_uneven_batch', batch_size=8, height=42, width=53), ) def test_ssim(self, batch_size: int, height: int, width: int) -> None: jax_inputs, pytorch_inputs = _create_fake_batch(batch_size, height, width) @@ -71,9 +68,8 @@ def test_ssim(self, batch_size: int, height: int, width: int) -> None: pytorch_ssim_result = pytorch_ssim(*pytorch_inputs) self.assertEqual(jax_ssim_result.shape, pytorch_ssim_result.shape) assert np.allclose( - jax_ssim_result.sum().item(), - pytorch_ssim_result.sum().item(), - atol=1e-6) + jax_ssim_result.sum().item(), pytorch_ssim_result.sum().item(), atol=1e-6 + ) if __name__ == '__main__': diff --git a/tests/test_traindiffs.py b/tests/test_traindiffs.py index cea589202..8acfc855a 100644 --- a/tests/test_traindiffs.py +++ b/tests/test_traindiffs.py @@ -3,28 +3,26 @@ Run it as: python3 test_traindiffs.py """ + import pickle import subprocess -from subprocess import DEVNULL -from subprocess import run -from subprocess import STDOUT +from subprocess import DEVNULL, STDOUT, run from absl import flags -from absl.testing import absltest -from absl.testing import parameterized +from absl.testing import absltest, parameterized from numpy import allclose FLAGS = flags.FLAGS WORKLOADS = [ - 'imagenet_resnet', - 'imagenet_vit', - 'wmt', - 'librispeech_conformer', - 'librispeech_deepspeech', - 'fastmri', - 'ogbg', - 'criteo1tb' + 'imagenet_resnet', + 'imagenet_vit', + 'wmt', + 'librispeech_conformer', + 'librispeech_deepspeech', + 'fastmri', + 'ogbg', + 'criteo1tb', ] GLOBAL_BATCH_SIZE = 16 NUM_TRAIN_STEPS = 10 @@ -35,7 +33,6 @@ class ModelDiffTest(parameterized.TestCase): - @parameterized.named_parameters(*named_parameters) def test_workload(self, workload): # pylint: disable=line-too-long, unnecessary-lambda-assignment @@ -50,24 +47,26 @@ def test_workload(self, workload): pytorch_logs_path = '/tmp/pyt_log.pkl' try: run( - f'XLA_PYTHON_CLIENT_ALLOCATOR=platform python -m 
tests.reference_algorithm_tests --workload={workload} --framework=jax --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={jax_logs_path}' - f' --submission_path=tests/modeldiffs/vanilla_sgd_jax.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}', - shell=True, - stdout=DEVNULL, - stderr=STDOUT, - check=True) + f'XLA_PYTHON_CLIENT_ALLOCATOR=platform python -m tests.reference_algorithm_tests --workload={workload} --framework=jax --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={jax_logs_path}' + f' --submission_path=tests/modeldiffs/vanilla_sgd_jax.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}', + shell=True, + stdout=DEVNULL, + stderr=STDOUT, + check=True, + ) except subprocess.CalledProcessError as e: - print("Error:", e) + print('Error:', e) try: run( - f'XLA_PYTHON_CLIENT_ALLOCATOR=platform torchrun --standalone --nnodes 1 --nproc_per_node 8 -m tests.reference_algorithm_tests --workload={workload} --framework=pytorch --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={pytorch_logs_path}' - f' --submission_path=tests/modeldiffs/vanilla_sgd_pytorch.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}', - shell=True, - stdout=DEVNULL, - stderr=STDOUT, - check=True) + f'XLA_PYTHON_CLIENT_ALLOCATOR=platform torchrun --standalone --nnodes 1 --nproc_per_node 8 -m tests.reference_algorithm_tests --workload={workload} --framework=pytorch --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={pytorch_logs_path}' + f' --submission_path=tests/modeldiffs/vanilla_sgd_pytorch.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}', + shell=True, + stdout=DEVNULL, + stderr=STDOUT, + check=True, + ) except subprocess.CalledProcessError as e: - print("Error:", e) + print('Error:', e) with open(jax_logs_path, 'rb') as f: jax_results = pickle.load(f) with open(pytorch_logs_path, 'rb') as f: @@ -75,19 +74,25 @@ def test_workload(self, workload): # PRINT RESULTS eval_metric_key = next( - iter( - filter(lambda k: 'train' in k and 'loss' in k, - jax_results['eval_results'][0]))) + iter( + filter( + lambda k: 'train' in k and 'loss' in k, jax_results['eval_results'][0] + ) + ) + ) header = [ - 'Iter', - 'Eval (jax)', - 'Eval (torch)', - 'Grad Norm (jax)', - 'Grad Norm (torch)', - 'Train Loss (jax)', - 'Train Loss (torch)', + 'Iter', + 'Eval (jax)', + 'Eval (torch)', + 'Grad Norm (jax)', + 'Grad Norm (torch)', + 'Train Loss (jax)', + 'Train Loss (torch)', ] - fmt = lambda l: '|' + '|'.join(map(lambda x: f'{x:^20s}', l)) + '|' + + def fmt(line): + return '|' + '|'.join(map(lambda x: f'{x:^20s}', line)) + '|' + header = fmt(header) pad = (len(header) - len((name))) // 2 print('=' * pad, name, '=' * (len(header) - len(name) - pad), sep='') @@ -97,33 +102,41 @@ def test_workload(self, workload): for i in range(NUM_TRAIN_STEPS): rtol = 1 - row = map(lambda x: str(round(x, 5)), - [ - jax_results['eval_results'][i][eval_metric_key], - pytorch_results['eval_results'][i][eval_metric_key], - jax_results['scalars'][i]['grad_norm'], - pytorch_results['scalars'][i]['grad_norm'], - jax_results['scalars'][i]['loss'], - pytorch_results['scalars'][i]['loss'], - ]) + row = map( + lambda x: str(round(x, 5)), + [ + jax_results['eval_results'][i][eval_metric_key], + pytorch_results['eval_results'][i][eval_metric_key], + jax_results['scalars'][i]['grad_norm'], + pytorch_results['scalars'][i]['grad_norm'], + jax_results['scalars'][i]['loss'], + pytorch_results['scalars'][i]['loss'], + ], + ) print(fmt([f'{i}', 
*row])) print('=' * len(header)) self.assertTrue( # eval_results - allclose( - jax_results['eval_results'][i][eval_metric_key], - pytorch_results['eval_results'][i][eval_metric_key], - rtol=rtol)) + allclose( + jax_results['eval_results'][i][eval_metric_key], + pytorch_results['eval_results'][i][eval_metric_key], + rtol=rtol, + ) + ) self.assertTrue( # grad_norms - allclose( - jax_results['scalars'][i]['grad_norm'], - pytorch_results['scalars'][i]['grad_norm'], - rtol=rtol)) + allclose( + jax_results['scalars'][i]['grad_norm'], + pytorch_results['scalars'][i]['grad_norm'], + rtol=rtol, + ) + ) self.assertTrue( # loss - allclose( - jax_results['scalars'][i]['loss'], - pytorch_results['scalars'][i]['loss'], - rtol=rtol)) + allclose( + jax_results['scalars'][i]['loss'], + pytorch_results['scalars'][i]['loss'], + rtol=rtol, + ) + ) if __name__ == '__main__': diff --git a/tests/test_version.py b/tests/test_version.py index d1bfbd18f..69384953a 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -6,10 +6,10 @@ def test_version_attribute(): """Check whether __version__ exists and is a valid string.""" - assert hasattr(algoperf, "__version__") + assert hasattr(algoperf, '__version__') version = algoperf.__version__ assert isinstance(version, str) - version_elements = version.split(".") + version_elements = version.split('.') print(version_elements) # Only check the first two elements, i.e. major, minor # (patch is not checked as it is not required). diff --git a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py index d44234927..60a1af2f2 100644 --- a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py +++ b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py @@ -1,12 +1,13 @@ """Tests for imagenet_resnet/imagenet_jax/workload.py.""" -from absl.testing import absltest import jax import jax.numpy as jnp +from absl.testing import absltest from algoperf import spec -from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ - ImagenetResNetWorkload +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ( + ImagenetResNetWorkload, +) def _pytree_total_diff(pytree_a, pytree_b): @@ -32,42 +33,48 @@ def test_forward_pass(self): # this function because we call it with a different combination of those two # args each time. Can't call with kwargs. pmapped_model_fn = jax.pmap( - workload.model_fn, - axis_name='batch', - in_axes=(0, 0, 0, None, None, None), - static_broadcasted_argnums=(3, 5)) + workload.model_fn, + axis_name='batch', + in_axes=(0, 0, 0, None, None, None), + static_broadcasted_argnums=(3, 5), + ) logits, updated_batch_stats = pmapped_model_fn( - model_params, - {'inputs': first_input_batch}, - batch_stats, - spec.ForwardPassMode.TRAIN, - rng, - True) + model_params, + {'inputs': first_input_batch}, + batch_stats, + spec.ForwardPassMode.TRAIN, + rng, + True, + ) self.assertEqual(logits.shape, expected_logits_shape) # Test that batch stats are updated. self.assertNotEqual( - _pytree_total_diff(batch_stats, updated_batch_stats), 0.0) + _pytree_total_diff(batch_stats, updated_batch_stats), 0.0 + ) second_input_batch = jax.random.normal(data_rngs[1], shape=input_shape) # Test that batch stats are not updated when we say so. 
     _, same_batch_stats = pmapped_model_fn(
-        model_params,
-        {'inputs': second_input_batch},
-        updated_batch_stats,
-        spec.ForwardPassMode.TRAIN,
-        rng,
-        False)
+      model_params,
+      {'inputs': second_input_batch},
+      updated_batch_stats,
+      spec.ForwardPassMode.TRAIN,
+      rng,
+      False,
+    )
     self.assertEqual(
-        _pytree_total_diff(same_batch_stats, updated_batch_stats), 0.0)
+      _pytree_total_diff(same_batch_stats, updated_batch_stats), 0.0
+    )
 
     # Test eval model.
     logits, _ = pmapped_model_fn(
-        model_params,
-        {'inputs': second_input_batch},
-        batch_stats,
-        spec.ForwardPassMode.EVAL,
-        rng,
-        False)
+      model_params,
+      {'inputs': second_input_batch},
+      batch_stats,
+      spec.ForwardPassMode.EVAL,
+      rng,
+      False,
+    )
     self.assertEqual(logits.shape, expected_logits_shape)