diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 33a14f83c..38867b369 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -216,7 +216,7 @@ pylint tests ## Unit and integration tests We run unit tests and integration tests as part of the of github actions as well. -You can also use `python tests/reference_algorithm_tests.py` to run a single model update and two model evals for each workload using the reference algorithm in `reference_algorithms/development_algorithms/`. +You can also use `python tests/reference_algorithm_tests.py` to run a single model update and two model evals for each workload using the reference algorithm in `reference_algorithms/target_setting_algorithms/`. ## Regression tests We also have regression tests available in [.github/workflows/regression_tests.yml](https://github.com/mlcommons/algorithmic-efficiency/tree/main/.github/workflows/regression_tests.yml) that can be run semi-automatically. diff --git a/reference_algorithms/development_algorithms/README.md b/reference_algorithms/development_algorithms/README.md deleted file mode 100644 index 12b1b1f8e..000000000 --- a/reference_algorithms/development_algorithms/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Development Algorithms - -These are various algorithms used during the testing and development of the codebase. - -These are not valid submissions, because they use a different hyperparameter settings and algorithms per workload. diff --git a/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_jax/__init__.py b/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_jax/submission.py b/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_jax/submission.py deleted file mode 100644 index 4dea0c321..000000000 --- a/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_jax/submission.py +++ /dev/null @@ -1,142 +0,0 @@ -"""Training algorithm track submission functions for Criteo1TB DLRM-Small.""" - -import functools -from typing import Dict, Iterator, List, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 524_288 // 2 - - -def create_learning_rate_fn(workload: spec.Workload, - hparams: spec.Hyperparameters): - """Create learning rate schedule.""" - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hparams.learning_rate, - transition_steps=hparams.warmup_steps) - cosine_fn = optax.cosine_decay_schedule( - init_value=hparams.learning_rate, - decay_steps=(workload.step_hint - hparams.warmup_steps)) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[hparams.warmup_steps]) - return schedule_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_params - del model_state - del rng - learning_rate_fn = create_learning_rate_fn(workload, hyperparameters) - opt_init_fn, opt_update_fn = optax.adamw( - learning_rate=learning_rate_fn, - b1=hyperparameters.beta1, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng): - - def _loss_fn(params): - """loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=False) - loss_dict = workload.loss_fn(batch['targets'], logits) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - return loss, new_model_state - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (loss, new_model_state), grad = grad_fn(current_param_container) - (loss, grad) = lax.pmean((loss, grad), axis_name='batch') - grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_model_state, new_optimizer_state, updated_params, loss, grad_norm - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - # del global_step - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - new_model_state, new_optimizer_state, new_params, loss, grad_norm = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, batch, per_device_rngs) - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_pytorch/__init__.py b/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_pytorch/submission.py b/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_pytorch/submission.py deleted file mode 100644 index d9d9c29b5..000000000 --- a/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_pytorch/submission.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import Dict, Iterator, List, Tuple - -import torch -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - batch_sizes = {'criteo1tb': 524_288} - return batch_sizes[workload_name] - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del rng - del model_state - - base_lr = hyperparameters.learning_rate - - optimizer_state = { - 'optimizer': - torch.optim.AdamW( - model_params.parameters(), - lr=base_lr, - weight_decay=hyperparameters.weight_decay, - betas=(hyperparameters.beta1, 0.999)), - } - - scheduler1 = LinearLR( - optimizer_state['optimizer'], - start_factor=1e-12, - end_factor=1., - total_iters=hyperparameters.warmup_steps) - - scheduler2 = CosineAnnealingLR( - optimizer_state['optimizer'], - T_max=(workload.step_hint - hyperparameters.warmup_steps), - ) - - optimizer_state['scheduler'] = SequentialLR( - optimizer_state['optimizer'], - schedulers=[scheduler1, scheduler2], - milestones=[hyperparameters.warmup_steps]) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - - current_model = current_param_container - current_param_container.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - dropout_rate=None, - aux_dropout_rate=None, - update_batch_norm=False) - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], logits_batch=logits_batch) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - - loss.backward() - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - return (optimizer_state, current_param_container, new_model_state) - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/criteo1tb/tuning_search_space.json b/reference_algorithms/development_algorithms/criteo1tb/tuning_search_space.json deleted file mode 100644 index a30292bdb..000000000 --- a/reference_algorithms/development_algorithms/criteo1tb/tuning_search_space.json +++ /dev/null @@ -1,27 +0,0 @@ -{ - "learning_rate": { - "feasible_points": [ - 0.0065686501947063445 - ] - }, - "beta1": { - "feasible_points": [ - 0.8743797750166902 - ] - }, - "beta2": { - "feasible_points": [ - 0.9980006182116233 - ] - }, - "warmup_steps": { - "feasible_points": [ - 800 - ] - }, - "weight_decay": { - "feasible_points": [ - 1.5301171352729387e-5 - ] - } -} diff --git a/reference_algorithms/development_algorithms/fastmri/__init__.py b/reference_algorithms/development_algorithms/fastmri/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/fastmri/fastmri_jax/__init__.py b/reference_algorithms/development_algorithms/fastmri/fastmri_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/fastmri/fastmri_jax/submission.py b/reference_algorithms/development_algorithms/fastmri/fastmri_jax/submission.py deleted file mode 100644 index 73b020112..000000000 --- a/reference_algorithms/development_algorithms/fastmri/fastmri_jax/submission.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Training algorithm track submission functions for FastMRI in Jax.""" - -import functools -from typing import Dict, Iterator, List, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 64 - - -def create_learning_rate_fn(hparams: spec.Hyperparameters, - steps_per_epoch: int): - """Create learning rate schedule.""" - max_num_train_steps = 500 * steps_per_epoch - decay_epoch_period = hparams.lr_step_size * steps_per_epoch - decay_events = range(decay_epoch_period, - max_num_train_steps, - decay_epoch_period) - schedule_fn = optax.piecewise_constant_schedule( - init_value=hparams.learning_rate, - boundaries_and_scales={t: hparams.lr_gamma for t in decay_events}) - return schedule_fn - - -def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): - steps_per_epoch = num_train_examples // get_batch_size('imagenet_resnet') - learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch) - opt_init_fn, opt_update_fn = optax.rmsprop( - learning_rate=learning_rate_fn, - decay=0.99) - return opt_init_fn, opt_update_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_params - del model_state - del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters, - workload.num_train_examples) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, None, 0, 0), - static_broadcasted_argnums=(0, 1, 4)) -def pmapped_train_step(workload, - opt_update_fn, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng): - - def _loss_fn(params): - """loss function used for training.""" - logits, _ = workload.model_fn( - params, - batch, - model_state=None, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - loss_dict = workload.loss_fn(batch['targets'], logits) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - weight_penalty_params = jax.tree_util.tree_leaves(params) - weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1) - weight_penalty = hyperparameters.l2 * 0.5 * weight_l2 - loss = loss + weight_penalty - return loss - - grad_fn = jax.grad(_loss_fn) - grad = grad_fn(current_param_container) - grad = lax.pmean(grad, axis_name='batch') - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - - return new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del model_state - del loss_type - del eval_results - del global_step - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - new_optimizer_state, new_params = pmapped_train_step( - workload, opt_update_fn, optimizer_state, - current_param_container, hyperparameters, batch, per_device_rngs) - return (new_optimizer_state, opt_update_fn), new_params, None - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/fastmri/fastmri_pytorch/__init__.py b/reference_algorithms/development_algorithms/fastmri/fastmri_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/fastmri/fastmri_pytorch/submission.py b/reference_algorithms/development_algorithms/fastmri/fastmri_pytorch/submission.py deleted file mode 100644 index 38828d4c3..000000000 --- a/reference_algorithms/development_algorithms/fastmri/fastmri_pytorch/submission.py +++ /dev/null @@ -1,104 +0,0 @@ -"""Training algorithm track submission functions for FastMRI.""" - -from typing import Dict, Iterator, List, Tuple - -import torch -from torch.optim.lr_scheduler import StepLR - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - batch_sizes = {'fastmri': 8} - return batch_sizes[workload_name] - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del workload - del model_state - del rng - - base_lr = hyperparameters.learning_rate * get_batch_size('fastmri') - optimizer_state = { - 'optimizer': - torch.optim.RMSprop( - model_params.parameters(), - lr=base_lr, - weight_decay=hyperparameters.l2), - } - - optimizer_state['scheduler'] = StepLR( - optimizer_state['optimizer'], - step_size=hyperparameters.lr_step_size, - gamma=hyperparameters.lr_gamma) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del loss_type - del eval_results - - current_model = current_param_container - current_param_container.train() - optimizer_state['optimizer'].zero_grad() - - outputs_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], logits_batch=outputs_batch) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - - loss.backward() - optimizer_state['optimizer'].step() - steps_per_epoch = workload.num_train_examples // get_batch_size('fastmri') - if (global_step + 1) % steps_per_epoch == 0: - optimizer_state['scheduler'].step() - - return (optimizer_state, current_param_container, new_model_state) - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/fastmri/tuning_search_space.json b/reference_algorithms/development_algorithms/fastmri/tuning_search_space.json deleted file mode 100644 index 01e4e00c2..000000000 --- a/reference_algorithms/development_algorithms/fastmri/tuning_search_space.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "learning_rate": {"feasible_points": [0.001]}, - "num_epochs": {"feasible_points": [50]}, - "l2": {"feasible_points": [0.0]}, - "lr_step_size": {"feasible_points": [40]}, - "lr_gamma": {"feasible_points": [0.1]} -} \ No newline at end of file diff --git a/reference_algorithms/development_algorithms/imagenet_resnet/__init__.py b/reference_algorithms/development_algorithms/imagenet_resnet/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_jax/__init__.py b/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_jax/submission.py b/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_jax/submission.py deleted file mode 100644 index 9c686d524..000000000 --- a/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_jax/submission.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Training algorithm track submission functions for ImageNet.""" - -import functools -from typing import Dict, Iterator, List, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 1024 - - -def create_learning_rate_fn(hparams: spec.Hyperparameters, - steps_per_epoch: int): - """Create learning rate schedule.""" - base_learning_rate = hparams.learning_rate * \ - get_batch_size('imagenet_resnet') / 256. - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=base_learning_rate, - transition_steps=hparams.warmup_epochs * steps_per_epoch) - cosine_epochs = max(hparams.num_epochs - hparams.warmup_epochs, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=base_learning_rate, - decay_steps=cosine_epochs * steps_per_epoch) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], - boundaries=[hparams.warmup_epochs * steps_per_epoch]) - return schedule_fn - - -def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): - steps_per_epoch = num_train_examples // get_batch_size('imagenet_resnet') - learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch) - opt_init_fn, opt_update_fn = optax.sgd( - nesterov=True, - momentum=hyperparameters.momentum, - learning_rate=learning_rate_fn) - return opt_init_fn, opt_update_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_params - del model_state - del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters, - workload.num_train_examples) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng): - - def _loss_fn(params): - """loss function used for training.""" - variables = {'params': params, **model_state} - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss_dict = workload.loss_fn(batch['targets'], logits) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - weight_penalty_params = jax.tree_util.tree_leaves(variables['params']) - weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1) - weight_penalty = hyperparameters.l2 * 0.5 * weight_l2 - loss = loss + weight_penalty - return loss, (new_model_state, logits) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - aux, grad = grad_fn(current_param_container) - grad = lax.pmean(grad, axis_name='batch') - new_model_state, _ = aux[1] - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - - return new_model_state, new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del global_step - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - new_model_state, new_optimizer_state, new_params = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, hyperparameters, batch, per_device_rngs) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_pytorch/__init__.py b/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_pytorch/submission.py b/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_pytorch/submission.py deleted file mode 100644 index 694e924f7..000000000 --- a/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_pytorch/submission.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Training algorithm track submission functions for ImageNet.""" -from typing import Dict, Iterator, List, Tuple - -import torch -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 1024 - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_state - del rng - - batch_size = get_batch_size('imagenet_resnet') - base_lr = hyperparameters.learning_rate * batch_size / 256. - optimizer_state = { - 'optimizer': - torch.optim.SGD( - model_params.parameters(), - lr=base_lr, - momentum=hyperparameters.momentum, - weight_decay=hyperparameters.l2, - nesterov=True), - } - - steps_per_epoch = workload.num_train_examples // batch_size - scheduler1 = LinearLR( - optimizer_state['optimizer'], - start_factor=1e-10, - end_factor=1., - total_iters=hyperparameters.warmup_epochs * steps_per_epoch) - cosine_epochs = max( - hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1) - scheduler2 = CosineAnnealingLR( - optimizer_state['optimizer'], T_max=cosine_epochs * steps_per_epoch) - - optimizer_state['scheduler'] = SequentialLR( - optimizer_state['optimizer'], - schedulers=[scheduler1, scheduler2], - milestones=[hyperparameters.warmup_epochs * steps_per_epoch]) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del hyperparameters - del loss_type - del eval_results - del global_step - - current_model = current_param_container - current_param_container.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], logits_batch=logits_batch) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - - loss.backward() - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - return (optimizer_state, current_param_container, new_model_state) - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/imagenet_resnet/tuning_search_space.json b/reference_algorithms/development_algorithms/imagenet_resnet/tuning_search_space.json deleted file mode 100644 index da969416b..000000000 --- a/reference_algorithms/development_algorithms/imagenet_resnet/tuning_search_space.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "learning_rate": {"feasible_points": [0.1]}, - "warmup_epochs": {"feasible_points": [5]}, - "num_epochs": {"feasible_points": [100]}, - "l2": {"feasible_points": [1e-4]}, - "momentum": {"feasible_points": [0.9]} -} \ No newline at end of file diff --git a/reference_algorithms/development_algorithms/imagenet_vit/__init__.py b/reference_algorithms/development_algorithms/imagenet_vit/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/imagenet_vit/imagenet_jax/__init__.py b/reference_algorithms/development_algorithms/imagenet_vit/imagenet_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/imagenet_vit/imagenet_jax/submission.py b/reference_algorithms/development_algorithms/imagenet_vit/imagenet_jax/submission.py deleted file mode 100644 index 4d65d9675..000000000 --- a/reference_algorithms/development_algorithms/imagenet_vit/imagenet_jax/submission.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Training algorithm track submission functions for ImageNet.""" - -import functools -from typing import Dict, Iterator, List, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 1024 - - -def create_learning_rate_fn(hparams: spec.Hyperparameters, - steps_per_epoch: int): - """Create learning rate schedule.""" - base_learning_rate = hparams.learning_rate * \ - get_batch_size('imagenet_vit') / 1024. - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=base_learning_rate, - transition_steps=hparams.warmup_epochs * steps_per_epoch) - cosine_epochs = max(hparams.num_epochs - hparams.warmup_epochs, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=base_learning_rate, - decay_steps=cosine_epochs * steps_per_epoch) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], - boundaries=[hparams.warmup_epochs * steps_per_epoch]) - return schedule_fn - - -def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): - steps_per_epoch = num_train_examples // get_batch_size('imagenet_vit') - learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch) - opt_init_fn, opt_update_fn = optax.adam( - b1=hyperparameters.beta1, - b2=hyperparameters.beta2, - eps=hyperparameters.epsilon, - learning_rate=learning_rate_fn) - return opt_init_fn, opt_update_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_params - del model_state - del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters, - workload.num_train_examples) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng): - - def _loss_fn(params): - """loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss_dict = workload.loss_fn(batch['targets'], logits) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - weight_penalty_params = jax.tree_util.tree_leaves(params) - weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1) - weight_penalty = hyperparameters.l2 * 0.5 * weight_l2 - loss = loss + weight_penalty - return loss, (new_model_state, logits) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - aux, grad = grad_fn(current_param_container) - grad = lax.pmean(grad, axis_name='batch') - new_model_state, _ = aux[1] - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - - return new_model_state, new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del global_step - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - new_model_state, new_optimizer_state, new_params = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, hyperparameters, batch, per_device_rngs) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/imagenet_vit/imagenet_pytorch/__init__.py b/reference_algorithms/development_algorithms/imagenet_vit/imagenet_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/imagenet_vit/imagenet_pytorch/submission.py b/reference_algorithms/development_algorithms/imagenet_vit/imagenet_pytorch/submission.py deleted file mode 100644 index eee2a01db..000000000 --- a/reference_algorithms/development_algorithms/imagenet_vit/imagenet_pytorch/submission.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Training algorithm track submission functions for ImageNet.""" -from typing import Dict, Iterator, List, Tuple - -import torch -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 1024 - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_state - del rng - - batch_size = get_batch_size('imagenet_vit') - base_lr = hyperparameters.learning_rate * batch_size / 1024. - optimizer_state = { - 'optimizer': - torch.optim.Adam( - model_params.parameters(), - lr=base_lr, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=hyperparameters.epsilon), - } - - steps_per_epoch = workload.num_train_examples // batch_size - scheduler1 = LinearLR( - optimizer_state['optimizer'], - start_factor=1e-10, - end_factor=1., - total_iters=hyperparameters.warmup_epochs * steps_per_epoch) - cosine_epochs = max( - hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1) - scheduler2 = CosineAnnealingLR( - optimizer_state['optimizer'], T_max=cosine_epochs * steps_per_epoch) - - optimizer_state['scheduler'] = SequentialLR( - optimizer_state['optimizer'], - schedulers=[scheduler1, scheduler2], - milestones=[hyperparameters.warmup_epochs * steps_per_epoch]) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del loss_type - del eval_results - del global_step - - current_model = current_param_container - current_param_container.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], logits_batch=logits_batch) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - - loss.backward() - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - return (optimizer_state, current_param_container, new_model_state) - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/imagenet_vit/tuning_search_space.json b/reference_algorithms/development_algorithms/imagenet_vit/tuning_search_space.json deleted file mode 100644 index e6cf84733..000000000 --- a/reference_algorithms/development_algorithms/imagenet_vit/tuning_search_space.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "learning_rate": {"feasible_points": [1e-3]}, - "beta1": {"feasible_points": [0.9]}, - "beta2": {"feasible_points": [0.999]}, - "epsilon": {"feasible_points": [1e-8]}, - "num_epochs": {"feasible_points": [100]}, - "warmup_epochs": {"feasible_points": [5]}, - "l2": {"feasible_points": [1e-1]} -} \ No newline at end of file diff --git a/reference_algorithms/development_algorithms/librispeech_conformer/__init__.py b/reference_algorithms/development_algorithms/librispeech_conformer/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_jax/__init__.py b/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_jax/submission.py b/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_jax/submission.py deleted file mode 100644 index ea314b820..000000000 --- a/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_jax/submission.py +++ /dev/null @@ -1,211 +0,0 @@ -"""Training algorithm track submission functions for LibriSpeech.""" -import functools -from typing import Dict, Iterator, List, Tuple - -from absl import logging -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import numpy as np -import optax - -from algorithmic_efficiency import spec - -_GRAD_CLIP_EPS = 1e-6 - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 256 - - -def get_learning_rate(step, hyperparams): - warmup_steps = hyperparams.warmup_steps - if step < warmup_steps: - current_lr = (step * hyperparams.base_lr) / warmup_steps - else: - decay_factor = (1 + np.cos(step / hyperparams.training_steps * np.pi)) * 0.5 - current_lr = hyperparams.base_lr * decay_factor - return current_lr - - -def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): - opt_init_fn, opt_update_fn = optax.inject_hyperparams(optax.adamw)( - b1=hyperparameters.beta1, - b2=hyperparameters.beta2, - eps=hyperparameters.epsilon, - weight_decay=hyperparameters.weight_decay, - learning_rate=0.0) - return opt_init_fn, opt_update_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_state - del rng - - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters, - workload.num_train_examples) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -def l2_regularization(params, l2_decay_rank_threshold): - """Computes the squared l2 norm of the given parameters. - - This function will only filter for parameters with - rank >= l2_decay_rank_threshold. So if this threshold is set to 2, then all - 1d (and lower) parameter arrays, including all bias and batch norm params, - will be ignored in this computation. - - - Args: - params: Pytree containing parameters. - l2_decay_rank_threshold: The calculation will only include parameters with - param.ndim >= l2_decay_rank_threshold. Set to 2 to ignore all bias and - batch_norm params in the model. - - Returns: - weight_l2: the squared l2 norm of all params matching the threshold. - """ - weight_penalty_params = jax.tree_util.tree_leaves(params) - weight_l2 = sum( - jnp.sum(x**2) - for x in weight_penalty_params - if x.ndim >= l2_decay_rank_threshold) - return weight_l2 - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0, None), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng, - lr): - optimizer_state.hyperparams['learning_rate'] = lr - - def _loss_fn(params): - """loss function used for training.""" - (logits, logit_paddings), new_model_state = workload.model_fn( - params, - batch, - model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - loss_dict = workload.loss_fn(batch['targets'], (logits, logit_paddings)) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - grad_clip = hyperparameters.grad_clip - grad_norm = jnp.sqrt(l2_regularization(grad, 0)) - scaled_grad = jax.tree_map( - lambda x: x / (grad_norm + _GRAD_CLIP_EPS) * grad_clip, grad) - grad = jax.lax.cond(grad_norm > grad_clip, - lambda _: scaled_grad, - lambda _: grad, - None) - - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - - return new_model_state, new_optimizer_state, updated_params, loss, grad_norm - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del eval_results - del loss_type - - lr = get_learning_rate(global_step, hyperparameters) - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - per_device_rngs, - lr) - new_model_state, new_optimizer_state, new_params, loss, grad_norm = outputs - - if global_step <= 1000 or global_step % 100 == 0: - logging.info('%d) loss = %0.3f, grad_norm = %0.3f lr = %0.6f', - global_step, - loss.mean(), - grad_norm.mean(), - lr) - - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'train_step_ctc_loss': loss.mean(), - 'grad_norm': grad_norm.mean(), - 'learning_rate': lr, - }, - global_step) - - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_pytorch/__init__.py b/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_pytorch/submission.py b/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_pytorch/submission.py deleted file mode 100644 index ce38d7509..000000000 --- a/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_pytorch/submission.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Training algorithm track submission functions for LibriSpeech.""" -from typing import Dict, Iterator, List, Tuple - -import numpy as np -import torch -import torch.distributed.nn as dist_nn - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - -device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu") -ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none') - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 256 - - -def get_learning_rate(step, hyperparams): - warmup_steps = hyperparams.warmup_steps - if step < warmup_steps: - current_lr = (step * hyperparams.base_lr) / warmup_steps - else: - decay_factor = (1 + np.cos(step / hyperparams.training_steps * np.pi)) * 0.5 - current_lr = hyperparams.base_lr * decay_factor - return current_lr - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del workload - del model_state - del rng - optimizer = torch.optim.AdamW( - params=model_params.parameters(), - lr=0.0, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=hyperparameters.epsilon, - weight_decay=hyperparameters.weight_decay) - return {'optimizer': optimizer} - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del eval_results - del model_state - del loss_type - optimizer = optimizer_state['optimizer'] - optimizer.zero_grad() - current_model = current_param_container - - (logits, logits_padding), _ = workload.model_fn( - current_model, - batch, - None, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn(batch['targets'], (logits, logits_padding)) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. - summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - - for g in optimizer.param_groups: - g['lr'] = get_learning_rate(global_step, hyperparameters) - if hasattr(hyperparameters, 'grad_clip'): - torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=hyperparameters.grad_clip) - optimizer.step() - return optimizer_state, current_param_container, None - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/librispeech_conformer/tuning_search_space.json b/reference_algorithms/development_algorithms/librispeech_conformer/tuning_search_space.json deleted file mode 100644 index 821288415..000000000 --- a/reference_algorithms/development_algorithms/librispeech_conformer/tuning_search_space.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "base_lr": {"feasible_points": [0.001997]}, - "beta1": {"feasible_points": [0.7132]}, - "beta2": {"feasible_points": [0.9982]}, - "epsilon": {"feasible_points": [1e-9]}, - "weight_decay": {"feasible_points":[0.026595]}, - "grad_clip": {"feasible_points": [5.0]}, - "warmup_steps" : {"feasible_points": [10000]}, - "training_steps" : {"feasible_points": [100000]} -} - diff --git a/reference_algorithms/development_algorithms/librispeech_deepspeech/__init__.py b/reference_algorithms/development_algorithms/librispeech_deepspeech/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_jax/__init__.py b/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_jax/submission.py b/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_jax/submission.py deleted file mode 100644 index f8a368f3f..000000000 --- a/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_jax/submission.py +++ /dev/null @@ -1,206 +0,0 @@ -"""Training algorithm track submission functions for LibriSpeech.""" -import functools -from typing import Dict, Iterator, List, Tuple - -from absl import logging -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import numpy as np -import optax - -from algorithmic_efficiency import spec - -_GRAD_CLIP_EPS = 1e-6 - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 256 - - -def get_learning_rate(step, hyperparams): - warmup_steps = hyperparams.warmup_steps - if step < warmup_steps: - current_lr = (step * hyperparams.base_lr) / warmup_steps - else: - decay_factor = (1 + np.cos(step / hyperparams.training_steps * np.pi)) * 0.5 - current_lr = hyperparams.base_lr * decay_factor - return current_lr - - -def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): - opt_init_fn, opt_update_fn = optax.inject_hyperparams(optax.adamw)( - b1=hyperparameters.beta1, - b2=hyperparameters.beta2, - eps=hyperparameters.epsilon, - weight_decay=hyperparameters.weight_decay, - learning_rate=0.0) - return opt_init_fn, opt_update_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_state - del rng - - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters, - workload.num_train_examples) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -def l2_regularization(params, l2_decay_rank_threshold): - """Computes the squared l2 norm of the given parameters. - - This function will only filter for parameters with - rank >= l2_decay_rank_threshold. So if this threshold is set to 2, then all - 1d (and lower) parameter arrays, including all bias and batch norm params, - will be ignored in this computation. - - - Args: - params: Pytree containing parameters. - l2_decay_rank_threshold: The calculation will only include parameters with - param.ndim >= l2_decay_rank_threshold. Set to 2 to ignore all bias and - batch_norm params in the model. - - Returns: - weight_l2: the squared l2 norm of all params matching the threshold. - """ - weight_penalty_params = jax.tree_util.tree_leaves(params) - weight_l2 = sum( - jnp.sum(x**2) - for x in weight_penalty_params - if x.ndim >= l2_decay_rank_threshold) - return weight_l2 - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0, None), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng, - lr): - optimizer_state.hyperparams['learning_rate'] = lr - - def _loss_fn(params): - """loss function used for training.""" - (logits, logit_paddings), new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn(batch['targets'], (logits, logit_paddings)) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - grad_norm = jnp.sqrt(l2_regularization(grad, 0)) - grad_clip = hyperparameters.grad_clip - grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) - grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) - - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - - return new_model_state, new_optimizer_state, updated_params, loss, grad_norm - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del eval_results - del loss_type - - lr = get_learning_rate(global_step, hyperparameters) - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - per_device_rngs, - lr) - new_model_state, new_optimizer_state, new_params, loss, grad_norm = outputs - - if global_step <= 1000 or global_step % 100 == 0: - logging.info('%d) loss = %0.3f, grad_norm = %0.3f lr = %0.6f', - global_step, - loss.mean(), - grad_norm.mean(), - lr) - if workload.summary_writer is not None: - workload.summary_writer.scalar('train_step_ctc_loss', - loss.mean(), - global_step) - workload.summary_writer.scalar('grad_norm', grad_norm.mean(), global_step) - workload.summary_writer.scalar('learning_rate', lr, global_step) - - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del optimizer_state - del current_param_container - del global_step - del rng - del hyperparameters - del workload - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_pytorch/__init__.py b/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_pytorch/submission.py b/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_pytorch/submission.py deleted file mode 100644 index 9170086a5..000000000 --- a/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_pytorch/submission.py +++ /dev/null @@ -1,116 +0,0 @@ -"""Training algorithm track submission functions for LibriSpeech.""" -from typing import Dict, Iterator, List, Tuple - -import numpy as np -import torch -import torch.distributed.nn as dist_nn - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 256 - - -def get_learning_rate(step, hyperparams): - warmup_steps = hyperparams.warmup_steps - if step < warmup_steps: - current_lr = (step * hyperparams.base_lr) / warmup_steps - else: - decay_factor = (1 + np.cos(step / hyperparams.training_steps * np.pi)) * 0.5 - current_lr = hyperparams.base_lr * decay_factor - return current_lr - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del workload - del model_state - del rng - optimizer = torch.optim.AdamW( - params=model_params.parameters(), - lr=0.0, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=hyperparameters.epsilon, - weight_decay=hyperparameters.weight_decay) - return {'optimizer': optimizer} - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del eval_results - del model_state - del loss_type - optimizer = optimizer_state['optimizer'] - optimizer.zero_grad() - current_model = current_param_container - - (logits, logits_padding), _ = workload.model_fn( - current_model, - batch, - None, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn(batch['targets'], (logits, logits_padding)) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. - summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - - for g in optimizer.param_groups: - g['lr'] = get_learning_rate(global_step, hyperparameters) - if hasattr(hyperparameters, 'grad_clip'): - torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=hyperparameters.grad_clip) - optimizer.step() - return optimizer_state, current_param_container, None - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/librispeech_deepspeech/tuning_search_space.json b/reference_algorithms/development_algorithms/librispeech_deepspeech/tuning_search_space.json deleted file mode 100644 index d337200c7..000000000 --- a/reference_algorithms/development_algorithms/librispeech_deepspeech/tuning_search_space.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "base_lr": {"feasible_points": [0.002632520052132928]}, - "beta1": {"feasible_points": [0.9945481149103774]}, - "beta2": {"feasible_points": [0.996379002889742]}, - "epsilon": {"feasible_points": [1e-8]}, - "weight_decay": {"feasible_points":[0.107175616660346]}, - "grad_clip": {"feasible_points": [5.0]}, - "warmup_steps" : {"feasible_points": [3000]}, - "training_steps" : {"feasible_points": [60000]} -} \ No newline at end of file diff --git a/reference_algorithms/development_algorithms/ogbg/__init__.py b/reference_algorithms/development_algorithms/ogbg/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/ogbg/ogbg_jax/__init__.py b/reference_algorithms/development_algorithms/ogbg/ogbg_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/ogbg/ogbg_jax/submission.py b/reference_algorithms/development_algorithms/ogbg/ogbg_jax/submission.py deleted file mode 100644 index 28b512589..000000000 --- a/reference_algorithms/development_algorithms/ogbg/ogbg_jax/submission.py +++ /dev/null @@ -1,122 +0,0 @@ -from typing import Dict, Iterator, List, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - batch_sizes = {'ogbg': 2048} - return batch_sizes[workload_name] - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates an Adam optimizer.""" - del model_params - del model_state - del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = opt_init_fn, opt_update_fn = optax.adam( - learning_rate=hyperparameters.learning_rate) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -def train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng): - del hyperparameters - - def _loss_fn(params): - logits_batch, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - mask_batch = batch['weights'] - loss_dict = workload.loss_fn(batch['targets'], logits_batch, mask_batch) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - updates, new_optimizer_state = opt_update_fn( - grad, optimizer_state, current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_model_state, new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del global_step - - optimizer_state, opt_update_fn = optimizer_state - pmapped_train_step = jax.pmap( - train_step, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0), - static_broadcasted_argnums=(0, 1)) - dropout_rngs = jax.random.split(rng, jax.local_device_count()) - new_model_state, new_optimizer_state, new_params = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, hyperparameters, batch, dropout_rngs) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/ogbg/ogbg_pytorch/__init__.py b/reference_algorithms/development_algorithms/ogbg/ogbg_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/ogbg/ogbg_pytorch/submission.py b/reference_algorithms/development_algorithms/ogbg/ogbg_pytorch/submission.py deleted file mode 100644 index 04f4baf9a..000000000 --- a/reference_algorithms/development_algorithms/ogbg/ogbg_pytorch/submission.py +++ /dev/null @@ -1,99 +0,0 @@ -from typing import Dict, Iterator, List, Tuple - -import torch -import torch.distributed.nn as dist_nn - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - - -def get_batch_size(workload_name): - # Return the global batch size. - batch_sizes = {'ogbg': 32768} - return batch_sizes[workload_name] - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates an Adam optimizer.""" - del workload - del model_state - del rng - optimizer_state = { - 'optimizer': - torch.optim.Adam( - model_params.parameters(), lr=hyperparameters.learning_rate), - } - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del hyperparameters - del loss_type - del eval_results - del global_step - - current_model = current_param_container - current_model.train() - optimizer_state['optimizer'].zero_grad() - - logits, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn(batch['targets'], logits, batch['weights']) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. - summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - optimizer_state['optimizer'].step() - - return optimizer_state, current_param_container, new_model_state - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/ogbg/tuning_search_space.json b/reference_algorithms/development_algorithms/ogbg/tuning_search_space.json deleted file mode 100644 index d50cc00c5..000000000 --- a/reference_algorithms/development_algorithms/ogbg/tuning_search_space.json +++ /dev/null @@ -1 +0,0 @@ -{"learning_rate": {"feasible_points": [1e-3]}} diff --git a/reference_algorithms/development_algorithms/wmt/__init__.py b/reference_algorithms/development_algorithms/wmt/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/wmt/tuning_search_space.json b/reference_algorithms/development_algorithms/wmt/tuning_search_space.json deleted file mode 100644 index ba3b24f8e..000000000 --- a/reference_algorithms/development_algorithms/wmt/tuning_search_space.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "learning_rate": {"feasible_points": [0.0625]}, - "one_minus_beta_1": {"feasible_points": [0.1]}, - "dropout_rate": {"feasible_points": [0.1]}, - "aux_dropout_rate": {"feasible_points": [0.1]}, - "epsilon": {"feasible_points": [1e-9]} -} - diff --git a/reference_algorithms/development_algorithms/wmt/wmt_jax/__init__.py b/reference_algorithms/development_algorithms/wmt/wmt_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/wmt/wmt_jax/submission.py b/reference_algorithms/development_algorithms/wmt/wmt_jax/submission.py deleted file mode 100644 index 9ef1580b2..000000000 --- a/reference_algorithms/development_algorithms/wmt/wmt_jax/submission.py +++ /dev/null @@ -1,194 +0,0 @@ -"""Training algorithm track submission functions for WMT.""" - -import functools -from typing import Dict, Iterator, List, Tuple - -from flax import jax_utils -import jax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - batch_sizes = {'wmt': 128} - return batch_sizes[workload_name] - - -def create_learning_rate_scheduler( - factors='constant * linear_warmup * rsqrt_decay', - base_learning_rate=0.5, - warmup_steps=1000, - decay_factor=0.5, - steps_per_decay=20000, - steps_per_cycle=100000): - """Creates learning rate schedule. - - Interprets factors in the factors string which can consist of: - * constant: interpreted as the constant value, - * linear_warmup: interpreted as linear warmup until warmup_steps, - * rsqrt_decay: divide by square root of max(step, warmup_steps) - * rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1) - * decay_every: Every k steps decay the learning rate by decay_factor. - * cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter. - - Args: - factors: string, factors separated by "*" that defines the schedule. - base_learning_rate: float, the starting constant for the lr schedule. - warmup_steps: int, how many steps to warm up for in the warmup schedule. - decay_factor: float, the amount to decay the learning rate by. - steps_per_decay: int, how often to decay the learning rate. - steps_per_cycle: int, steps per cycle when using cosine decay. - - Returns: - a function learning_rate(step): float -> {"learning_rate": float}, the - step-dependent lr. - """ - factors = [n.strip() for n in factors.split('*')] - - def step_fn(step): - """Step to learning rate function.""" - ret = 1.0 - for name in factors: - if name == 'constant': - ret *= base_learning_rate - elif name == 'linear_warmup': - ret *= jnp.minimum(1.0, step / warmup_steps) - elif name == 'rsqrt_decay': - ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) - elif name == 'rsqrt_normalized_decay': - ret *= jnp.sqrt(warmup_steps) - ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) - elif name == 'decay_every': - ret *= (decay_factor**(step // steps_per_decay)) - elif name == 'cosine_decay': - progress = jnp.maximum(0.0, - (step - warmup_steps) / float(steps_per_cycle)) - ret *= jnp.maximum(0.0, - 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0)))) - else: - raise ValueError(f'Unknown factor {name}.') - return jnp.asarray(ret, dtype=jnp.float32) - - return step_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_params - del model_state - del rng - learning_rate_fn = create_learning_rate_scheduler( - base_learning_rate=hyperparameters.learning_rate, warmup_steps=1000) - opt_init_fn, opt_update_fn = optax.adam( - b1=1.0 - hyperparameters.one_minus_beta_1, - b2=0.98, - eps=hyperparameters.epsilon, - learning_rate=learning_rate_fn) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - in_axes=(None, None, 0, 0, 0, 0, None), - axis_name='batch', - static_broadcasted_argnums=(0, 1, 6)) -def pmapped_train_step(workload, - opt_update_fn, - optimizer_state, - current_param_container, - batch, - dropout_rng, - hyperparameters): - """Perform a single training step.""" - del hyperparameters - - def _loss_fn(params): - """Loss function used for training.""" - logits, _ = workload.model_fn( - params, - batch, - model_state=None, - mode=spec.ForwardPassMode.TRAIN, - rng=dropout_rng, - update_batch_norm=False) - targets = batch['targets'] - weights = batch['weights'] - loss_dict = workload.loss_fn(targets, logits, weights, label_smoothing=0.1) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, n_valid_examples - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, n_valid_examples), grad = grad_fn(current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = jax.lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - updates, new_optimizer_state = opt_update_fn( - grad, optimizer_state, current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del eval_results - del global_step - del model_state - del loss_type - - optimizer_state, opt_update_fn = optimizer_state - dropout_rngs = jax.random.split(rng, jax.local_device_count()) - new_optimizer_state, updated_params = pmapped_train_step( - workload, - opt_update_fn, - optimizer_state, - current_param_container, - batch, - dropout_rngs, - hyperparameters) - return (new_optimizer_state, opt_update_fn), updated_params, None - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/wmt/wmt_pytorch/__init__.py b/reference_algorithms/development_algorithms/wmt/wmt_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/wmt/wmt_pytorch/submission.py b/reference_algorithms/development_algorithms/wmt/wmt_pytorch/submission.py deleted file mode 100644 index 2df681273..000000000 --- a/reference_algorithms/development_algorithms/wmt/wmt_pytorch/submission.py +++ /dev/null @@ -1,167 +0,0 @@ -from typing import Dict, Iterator, List, Tuple - -import numpy as np -import torch -import torch.distributed.nn as dist_nn - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - - -def get_batch_size(workload_name): - batch_sizes = {'wmt': 128} - return batch_sizes[workload_name] - - -def create_learning_rate_scheduler( - factors='constant * linear_warmup * rsqrt_decay', - base_learning_rate=0.5, - warmup_steps=1000, - decay_factor=0.5, - steps_per_decay=20000, - steps_per_cycle=100000): - """Creates learning rate schedule. - Interprets factors in the factors string which can consist of: - * constant: interpreted as the constant value, - * linear_warmup: interpreted as linear warmup until warmup_steps, - * rsqrt_decay: divide by square root of max(step, warmup_steps) - * rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1) - * decay_every: Every k steps decay the learning rate by decay_factor. - * cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter. - Args: - factors: string, factors separated by "*" that defines the schedule. - base_learning_rate: float, the starting constant for the lr schedule. - warmup_steps: int, how many steps to warm up for in the warmup schedule. - decay_factor: float, the amount to decay the learning rate by. - steps_per_decay: int, how often to decay the learning rate. - steps_per_cycle: int, steps per cycle when using cosine decay. - Returns: - a function learning_rate(step): float -> {"learning_rate": float}, the - step-dependent lr. - """ - factors = [n.strip() for n in factors.split('*')] - - def step_fn(step): - """Step to learning rate function.""" - ret = 1.0 - for name in factors: - if name == 'constant': - ret *= base_learning_rate - elif name == 'linear_warmup': - ret *= np.minimum(1.0, step / warmup_steps) - elif name == 'rsqrt_decay': - ret /= np.sqrt(np.maximum(step, warmup_steps)) - elif name == 'rsqrt_normalized_decay': - ret *= np.sqrt(warmup_steps) - ret /= np.sqrt(np.maximum(step, warmup_steps)) - elif name == 'decay_every': - ret *= (decay_factor**(step // steps_per_decay)) - elif name == 'cosine_decay': - progress = np.maximum(0.0, - (step - warmup_steps) / float(steps_per_cycle)) - ret *= np.maximum(0.0, 0.5 * (1.0 + np.cos(np.pi * (progress % 1.0)))) - else: - raise ValueError(f'Unknown factor {name}.') - return ret - - return step_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del workload - del model_state - del rng - - optimizer_state = { - 'optimizer': - torch.optim.Adam( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta_1, 0.98), - eps=hyperparameters.epsilon), - } - - optimizer_state['scheduler'] = create_learning_rate_scheduler( - base_learning_rate=hyperparameters.learning_rate) - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del hyperparameters - del loss_type - del eval_results - - current_model = current_param_container - current_param_container.train() - optimizer = optimizer_state['optimizer'] - optimizer.zero_grad() - - logits, _ = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=False) - - targets = batch['targets'] - weights = batch['weights'] - loss_dict = workload.loss_fn(targets, logits, weights, label_smoothing=0.1) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. - summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - - lr = optimizer_state['scheduler'](global_step).item() - for g in optimizer.param_groups: - g['lr'] = lr - optimizer.step() - - return (optimizer_state, current_param_container, None) - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index ae834f1f4..5c43b233b 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -9,8 +9,8 @@ Assumes that each reference submission is using the external tuning ruleset and that it is defined in: # pylint: disable=line-too-long -"reference_algorithms/development_algorithms/{workload}/{workload}_{framework}/submission.py" -"reference_algorithms/development_algorithms/{workload}/tuning_search_space.json". +"reference_algorithms/target_setting_algorithms/{workload}/{workload}_{framework}/submission.py" +"reference_algorithms/target_setting_algorithms/{workload}/tuning_search_space.json". python3 tests/reference_algorithm_tests.py \ --workload=criteo1tb \ @@ -19,6 +19,7 @@ --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py \ --tuning_search_space=reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json """ + import copy import functools import importlib @@ -499,10 +500,10 @@ def _make_paths(repo_location, framework, workload_name): else: dataset_name = workload_name workload_dir = ( - f'{repo_location}/reference_algorithms/development_algorithms/' + f'{repo_location}/reference_algorithms/target_setting_algorithms/' f'{workload_name}') search_space_path = f'{workload_dir}/tuning_search_space.json' - submission_path = (f'reference_algorithms/development_algorithms/' + submission_path = (f'reference_algorithms/target_setting_algorithms/' f'{workload_name}/{dataset_name}_{framework}/' 'submission.py') full_submission_path = f'{repo_location}/{submission_path}' @@ -534,7 +535,7 @@ def test_submission(self): if FLAGS.tuning_search_space: raise ValueError('Cannot set --tuning_search_space and --all.') references_dir = ( - f'{repo_location}/reference_algorithms/development_algorithms') + f'{repo_location}/reference_algorithms/target_setting_algorithms') for workload_name in os.listdir(references_dir): for framework in ['jax', 'pytorch']: if framework == 'pytorch':