From 0f20678459b150526ce7f9de2923795cbb895d28 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 26 Sep 2023 01:36:35 +0000 Subject: [PATCH 01/48] set max split size --- submission_runner.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/submission_runner.py b/submission_runner.py index 2289d39d3..8370fd9df 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -50,6 +50,9 @@ # disable only for deepspeech if it works fine for other workloads. os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' +# os.environ["CUDA_VISIBLE_DEVICES"]='0' +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256" + # TODO(znado): make a nicer registry of workloads that lookup in. BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR From 6cf192a870be0e5340190eafa51a46661c3b2d6a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 01:18:05 +0000 Subject: [PATCH 02/48] tune max split size --- submission_runner.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 8370fd9df..24a1ea360 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -43,18 +43,14 @@ from algorithmic_efficiency.pytorch_utils import sync_ddp_time from algorithmic_efficiency.workloads import workloads + # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.set_visible_devices([], 'GPU') # disable only for deepspeech if it works fine for other workloads. -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' - -# os.environ["CUDA_VISIBLE_DEVICES"]='0' -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256" - -# TODO(znado): make a nicer registry of workloads that lookup in. -BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=falseos.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"# TODO(znado): make a nicer registry of workloads that lookup in. +os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' # Workload_path will be appended by '_pytorch' or '_jax' automatically. WORKLOADS = workloads.WORKLOADS @@ -209,9 +205,8 @@ def train_once( model_params, model_state = workload.init_model_fn( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: - compile_error_workloads = ['ogbg', 'criteo1tb'] - eager_backend_workloads = [ - 'librispeech_conformer', 'librispeech_deepspeech' + compile_error_workloads = ['ogbg', 'criteo1tb', 'librispeech_conformer'] + eager_backend_workloads = ['librispeech_deepspeech' ] aot_eager_backend_workloads = [] if FLAGS.workload in compile_error_workloads: @@ -422,7 +417,6 @@ def train_once( _reset_cuda_mem() train_state['last_step_end_time'] = get_time() - metrics = {'eval_results': eval_results, 'global_step': global_step} if log_dir is not None: From b2f8ff91c4a15de1a0224754719cb768ab3467a5 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 01:19:53 +0000 Subject: [PATCH 03/48] typo --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 24a1ea360..cf70746b8 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -49,7 +49,7 @@ tf.config.set_visible_devices([], 'GPU') # disable only for deepspeech if it works fine for other workloads. 
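The hunk above moves librispeech_conformer from the eager-backend list into the set of workloads where torch.compile is skipped outright, leaving librispeech_deepspeech on the eager backend. A minimal sketch of what those buckets mean in practice; only the workload lists come from the diff, the helper name and the toy module are illustrative:

import torch

def maybe_compile(model, workload_name):
  # Illustrative sketch; the lists mirror the hunk above.
  compile_error_workloads = ['ogbg', 'criteo1tb', 'librispeech_conformer']
  eager_backend_workloads = ['librispeech_deepspeech']
  if workload_name in compile_error_workloads:
    return model  # torch.compile is skipped entirely for these workloads
  if workload_name in eager_backend_workloads:
    # Dynamo still captures the graph but executes it eagerly, keeping the
    # compile wrapper in place while bypassing inductor codegen.
    return torch.compile(model, backend='eager')
  return torch.compile(model)  # default inductor backend

compiled = maybe_compile(torch.nn.Linear(4, 4), 'librispeech_deepspeech')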
-os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=falseos.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"# TODO(znado): make a nicer registry of workloads that lookup in. +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' # Workload_path will be appended by '_pytorch' or '_jax' automatically. From 70625b09aabf0a2483dc9d61eb5addb1886b8ed5 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 01:24:33 +0000 Subject: [PATCH 04/48] add back deleted block --- submission_runner.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/submission_runner.py b/submission_runner.py index cf70746b8..c2b5bbd3a 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -55,6 +55,9 @@ # Workload_path will be appended by '_pytorch' or '_jax' automatically. WORKLOADS = workloads.WORKLOADS +# TODO(znado): make a nicer registry of workloads that lookup in. +BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR + flags.DEFINE_string( 'submission_path', None, From 179abba5b6af04ec2bbf3a150619852a4a969db1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 01:26:44 +0000 Subject: [PATCH 05/48] undo disable torch compile for conformer --- submission_runner.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index c2b5bbd3a..272f17ca2 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -52,12 +52,12 @@ os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' -# Workload_path will be appended by '_pytorch' or '_jax' automatically. -WORKLOADS = workloads.WORKLOADS - # TODO(znado): make a nicer registry of workloads that lookup in. BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR +# Workload_path will be appended by '_pytorch' or '_jax' automatically. 
+WORKLOADS = workloads.WORKLOADS + flags.DEFINE_string( 'submission_path', None, @@ -420,6 +420,7 @@ def train_once( _reset_cuda_mem() train_state['last_step_end_time'] = get_time() + metrics = {'eval_results': eval_results, 'global_step': global_step} if log_dir is not None: From a2beafbfac4124172dd39f8ee9584d9bb287aee3 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 02:15:10 +0000 Subject: [PATCH 06/48] remove whitespace --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 272f17ca2..568718798 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -420,7 +420,7 @@ def train_once( _reset_cuda_mem() train_state['last_step_end_time'] = get_time() - + metrics = {'eval_results': eval_results, 'global_step': global_step} if log_dir is not None: From 255d8350db0ff3fdca5a0b905997fcd53336547e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 02:38:40 +0000 Subject: [PATCH 07/48] remove training whitespace --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 568718798..846894203 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -420,7 +420,7 @@ def train_once( _reset_cuda_mem() train_state['last_step_end_time'] = get_time() - + metrics = {'eval_results': eval_results, 'global_step': global_step} if log_dir is not None: From b7f4cbc4f92241e85937de66ad9be46b3da45687 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 02:39:05 +0000 Subject: [PATCH 08/48] isort fix --- submission_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 846894203..99db81cc7 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -43,7 +43,6 @@ from algorithmic_efficiency.pytorch_utils import sync_ddp_time from algorithmic_efficiency.workloads import workloads - # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. 
tf.config.set_visible_devices([], 'GPU') From bb296028da8d8170d88a25684b69e8952be29e57 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 18:06:30 +0000 Subject: [PATCH 09/48] formatting --- submission_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 99db81cc7..cfbbd881a 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -208,8 +208,7 @@ def train_once( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = ['ogbg', 'criteo1tb', 'librispeech_conformer'] - eager_backend_workloads = ['librispeech_deepspeech' - ] + eager_backend_workloads = ['librispeech_deepspeech'] aot_eager_backend_workloads = [] if FLAGS.workload in compile_error_workloads: logging.warning( From 3738f35b98887ff35122f8f5d9fa8fd187fde532 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 21:02:18 +0000 Subject: [PATCH 10/48] print step hint --- algorithmic_efficiency/checkpoint_utils.py | 2 +- reference_algorithms/target_setting_algorithms/jax_nadamw.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/checkpoint_utils.py b/algorithmic_efficiency/checkpoint_utils.py index fb7449b99..3ebace5a4 100644 --- a/algorithmic_efficiency/checkpoint_utils.py +++ b/algorithmic_efficiency/checkpoint_utils.py @@ -192,7 +192,7 @@ def save_checkpoint(framework: str, """ if framework == 'jax': model_params = jax.device_get(jax_utils.unreplicate(model_params)) - opt_state, _ = optimizer_state + opt_state, _, _ = optimizer_state opt_state = jax.device_get(jax_utils.unreplicate(opt_state)) model_state = jax.device_get(jax_utils.unreplicate(model_state)) else: diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 21f2a7b2b..dc0e3eb09 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -7,6 +7,7 @@ import jax import jax.numpy as jnp import optax +from absl import logging from algorithmic_efficiency import spec from reference_algorithms.target_setting_algorithms import cosine_warmup @@ -152,6 +153,7 @@ def init_optimizer_state(workload: spec.Workload, del rng target_setting_step_hint = int(0.75 * workload.step_hint) + logging.info(f'target setting step hint: {target_setting_step_hint}') lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, hyperparameters) From 7f7891df9e7e29eb5482f85317d57790b08594c1 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Mon, 2 Oct 2023 14:17:19 -0400 Subject: [PATCH 11/48] minor --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index ba8db9ced..ae5da49bc 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -140,6 +140,7 @@ def _eval_batch_pmapped(self, summed_loss = self.loss_fn( label_batch=batch['targets'], logits_batch=logits, mask_batch=weights)['summed'] + print(summed_loss) return summed_loss def _eval_batch(self, From f081dd10dfe968dc43cba0bfbaa7b4d3aa49786e Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Mon, 2 Oct 2023 14:28:03 -0400 Subject: [PATCH 12/48] minor --- 
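The next few hunks (patches 11 through 15) settle on casting the Criteo eval loss to float64 on the host rather than inside the pmapped function, where values are traced float32 device arrays. A minimal, self-contained sketch of that pattern, assuming one float32 summed loss per local device; the toy squared-error loss and array shapes are illustrative, only the final np.array(..., dtype=np.float64) wrapper mirrors the diff:

import functools

import jax
import jax.numpy as jnp
import numpy as np

# Illustrative stand-in for the workload's pmapped eval step: returns one
# float32 summed loss per local device.
@functools.partial(jax.pmap, axis_name='batch')
def _eval_batch_pmapped(params, batch):
  per_example = (params['w'] * batch['x'] - batch['y']) ** 2  # toy loss
  return per_example.sum()

def _eval_batch(params, batch):
  # Sum the (local_device_count,) per-device results, then cast on the host.
  # Accumulating the running eval total in float64 limits float32 rounding
  # drift over many batches, without attempting the cast inside pmap.
  return np.array(_eval_batch_pmapped(params, batch).sum(), dtype=np.float64)

# Toy usage: leading axis is the local device count.
n = jax.local_device_count()
params = {'w': jnp.ones((n, 4))}
batch = {'x': jnp.ones((n, 4)), 'y': jnp.zeros((n, 4))}
print(_eval_batch(params, batch))  # numpy float64 scalar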
.../workloads/criteo1tb/criteo1tb_jax/workload.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index ae5da49bc..1ae818764 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -6,7 +6,7 @@ from flax import jax_utils import jax import jax.numpy as jnp - +import numpy as np from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax import models @@ -140,8 +140,7 @@ def _eval_batch_pmapped(self, summed_loss = self.loss_fn( label_batch=batch['targets'], logits_batch=logits, mask_batch=weights)['summed'] - print(summed_loss) - return summed_loss + return np.array(summed_loss, dtype=np.float64) def _eval_batch(self, params: spec.ParameterContainer, From 62d9ad746623d3eafe7d5248cdb375b232960a26 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Mon, 2 Oct 2023 14:31:12 -0400 Subject: [PATCH 13/48] minor --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index 1ae818764..f3dc0c66d 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -140,14 +140,14 @@ def _eval_batch_pmapped(self, summed_loss = self.loss_fn( label_batch=batch['targets'], logits_batch=logits, mask_batch=weights)['summed'] - return np.array(summed_loss, dtype=np.float64) + return summed_loss def _eval_batch(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]) -> spec.Tensor: # We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of # shape (local_device_count,) will all be different values. 
- return self._eval_batch_pmapped(params, batch).sum() + return np.array(self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64) class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): From 19accc709a35c18132e73751b7f5763148d6946f Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Mon, 2 Oct 2023 17:10:17 -0400 Subject: [PATCH 14/48] Lint fix --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index f3dc0c66d..dc1696bd0 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -7,6 +7,7 @@ import jax import jax.numpy as jnp import numpy as np + from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax import models From a1844c768d62ba3592bac897aa92cbe15b7ef9cd Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Mon, 2 Oct 2023 17:13:36 -0400 Subject: [PATCH 15/48] Lint fix --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index dc1696bd0..a76a70289 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -148,7 +148,8 @@ def _eval_batch(self, batch: Dict[str, spec.Tensor]) -> spec.Tensor: # We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of # shape (local_device_count,) will all be different values. - return np.array(self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64) + return np.array( + self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64) class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): From 0e4dd85ef78536e0f02e301eaa73876f4fb79ba6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 6 Oct 2023 23:33:06 +0000 Subject: [PATCH 16/48] make pytorch cuda alloc config specific to conformer --- submission_runner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index cfbbd881a..5fecdc870 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -49,7 +49,6 @@ # disable only for deepspeech if it works fine for other workloads. os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' -os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' # TODO(znado): make a nicer registry of workloads that lookup in. 
BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR @@ -227,7 +226,9 @@ def train_once( else: logging.info('Performing `torch.compile`.') model_params = torch.compile(model_params) - + # Temporary fix for Conformer OOM + if flags.framework == 'pytorch' and flags.workload == 'librispeech_conformer': + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512' logging.info('Initializing optimizer.') with profiler.profile('Initializing optimizer'): optimizer_state = init_optimizer_state(workload, From da89a8bb31080a29458206e60b7906152eea3a83 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 6 Oct 2023 23:42:01 +0000 Subject: [PATCH 17/48] tune max split size --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 5fecdc870..069d2ffda 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -228,7 +228,7 @@ def train_once( model_params = torch.compile(model_params) # Temporary fix for Conformer OOM if flags.framework == 'pytorch' and flags.workload == 'librispeech_conformer': - os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512' + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' logging.info('Initializing optimizer.') with profiler.profile('Initializing optimizer'): optimizer_state = init_optimizer_state(workload, From 416b88dd2d621b2f99f69d58bde3f0eae216146d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Oct 2023 00:47:45 +0000 Subject: [PATCH 18/48] fix --- submission_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 069d2ffda..1460f5573 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -227,8 +227,8 @@ def train_once( logging.info('Performing `torch.compile`.') model_params = torch.compile(model_params) # Temporary fix for Conformer OOM - if flags.framework == 'pytorch' and flags.workload == 'librispeech_conformer': - os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' + if FLAGS.framework == 'pytorch' and FLAGS.workload == 'librispeech_conformer': + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512' logging.info('Initializing optimizer.') with profiler.profile('Initializing optimizer'): optimizer_state = init_optimizer_state(workload, From 4600d78657c1074b9754918e56e231d6d3c42a44 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Oct 2023 00:50:09 +0000 Subject: [PATCH 19/48] reduce max split size --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 1460f5573..0a56f49d7 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -228,7 +228,7 @@ def train_once( model_params = torch.compile(model_params) # Temporary fix for Conformer OOM if FLAGS.framework == 'pytorch' and FLAGS.workload == 'librispeech_conformer': - os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512' + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' logging.info('Initializing optimizer.') with profiler.profile('Initializing optimizer'): optimizer_state = init_optimizer_state(workload, From 7a764e145a00ca19d5f278b70706f37ced09649e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Oct 2023 00:55:53 +0000 Subject: [PATCH 20/48] move env var --- submission_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/submission_runner.py b/submission_runner.py index 0a56f49d7..bbbe78b58 100644 --- a/submission_runner.py +++ 
b/submission_runner.py @@ -49,6 +49,7 @@ # disable only for deepspeech if it works fine for other workloads. os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' +os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' # TODO(znado): make a nicer registry of workloads that lookup in. BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR From 1dbf3e4d8d62b8c392707309b7e0c1d84aae9a21 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Oct 2023 01:01:08 +0000 Subject: [PATCH 21/48] logging --- submission_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/submission_runner.py b/submission_runner.py index bbbe78b58..db4c7341f 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -230,6 +230,7 @@ def train_once( # Temporary fix for Conformer OOM if FLAGS.framework == 'pytorch' and FLAGS.workload == 'librispeech_conformer': os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' + logging.info("AHHHHHHhhhhhhhhhhhhhhhh") logging.info('Initializing optimizer.') with profiler.profile('Initializing optimizer'): optimizer_state = init_optimizer_state(workload, From 04f5c9419bd81b1d73e00c5889478a1ba421802c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Oct 2023 01:03:50 +0000 Subject: [PATCH 22/48] debugging --- submission_runner.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index db4c7341f..5a793dbf1 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -227,10 +227,6 @@ def train_once( else: logging.info('Performing `torch.compile`.') model_params = torch.compile(model_params) - # Temporary fix for Conformer OOM - if FLAGS.framework == 'pytorch' and FLAGS.workload == 'librispeech_conformer': - os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' - logging.info("AHHHHHHhhhhhhhhhhhhhhhh") logging.info('Initializing optimizer.') with profiler.profile('Initializing optimizer'): optimizer_state = init_optimizer_state(workload, @@ -583,8 +579,9 @@ def main(_): workload_metadata = WORKLOADS[FLAGS.workload] # Prevent OOM on librispeech conformer. - if FLAGS.workload == 'librispeech_conformer': - os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85' + # if FLAGS.workload == 'librispeech_conformer': + # os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85' + # os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' # Extend path according to framework. workload_metadata['workload_path'] = os.path.join( From b0b9f404e3a556c1fd364b229ea05135f770b496 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Oct 2023 01:06:55 +0000 Subject: [PATCH 23/48] debugging --- submission_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 5a793dbf1..69d55952a 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -579,9 +579,9 @@ def main(_): workload_metadata = WORKLOADS[FLAGS.workload] # Prevent OOM on librispeech conformer. - # if FLAGS.workload == 'librispeech_conformer': - # os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85' - # os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' + if FLAGS.workload == 'librispeech_conformer': + os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85' + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' # Extend path according to framework. 
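Patches 16 through 23 keep relocating the PYTORCH_CUDA_ALLOC_CONF and XLA_PYTHON_CLIENT_MEM_FRACTION settings between train_once(), main(), and module scope. Placement matters because both values are read lazily, when the respective framework first initializes its GPU allocator, so a value set after that point is ignored. A minimal sketch of the safe ordering, with the values taken from the diffs and the toy main() being illustrative:

import os

# Both knobs are consumed when the GPU allocators initialize (lazily, on the
# first CUDA/XLA allocation in the process), so they are set at module scope,
# before any framework touches the GPU.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85'

import torch  # noqa: E402  -- imported after the env vars on purpose


def main():
  x = torch.zeros(1024)
  if torch.cuda.is_available():
    # First CUDA allocation: the caching allocator picks up max_split_size_mb
    # here; changing the env var later in the process has no effect.
    x = x.cuda()
  print(x.device)


if __name__ == '__main__':
  main()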
workload_metadata['workload_path'] = os.path.join( From 318202ea0e1e8de77ceef12e3e2d5999e9330692 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Oct 2023 01:35:31 +0000 Subject: [PATCH 24/48] debug logging --- algorithmic_efficiency/logger_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 2b3cf86f6..097d8979f 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -278,7 +278,8 @@ def get_meta_data(workload: spec.Workload) -> dict: def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str): meta_data = get_meta_data(workload) meta_data.update({'rng_seed': rng_seed}) - write_json(meta_file_name, meta_data) + with open(meta_file_name, 'w') as f: + f.write(json.dumps(meta_data, indent=2)) class MetricLogger(object): From 3cec8c5e52918b1522b61e2b986e824ba1f3dc3a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Oct 2023 01:37:14 +0000 Subject: [PATCH 25/48] update --- algorithmic_efficiency/logger_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 097d8979f..f1e948c7f 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -277,7 +277,7 @@ def get_meta_data(workload: spec.Workload) -> dict: def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str): meta_data = get_meta_data(workload) - meta_data.update({'rng_seed': rng_seed}) + # meta_data.update({'rng_seed': rng_seed}) with open(meta_file_name, 'w') as f: f.write(json.dumps(meta_data, indent=2)) From 4fc6e1cab0e614ea140e066e60b34ef6d6f6df83 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Oct 2023 01:50:26 +0000 Subject: [PATCH 26/48] update_logging --- algorithmic_efficiency/logger_utils.py | 9 +++++---- submission_runner.py | 3 ++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index f1e948c7f..de5862b72 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -275,11 +275,12 @@ def get_meta_data(workload: spec.Workload) -> dict: return meta_data -def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str): +def get_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str): meta_data = get_meta_data(workload) - # meta_data.update({'rng_seed': rng_seed}) - with open(meta_file_name, 'w') as f: - f.write(json.dumps(meta_data, indent=2)) + meta_data.update({'rng_seed': rng_seed}) + # with open(meta_file_name, 'w') as f: + # f.write(json.dumps(meta_data, indent=2)) + return meta_data class MetricLogger(object): diff --git a/submission_runner.py b/submission_runner.py index 69d55952a..c3d0557e1 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -273,7 +273,8 @@ def train_once( checkpoint_dir=log_dir) meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') logging.info(f'Saving meta data to {meta_file_name}.') - logger_utils.save_meta_data(workload, rng_seed, preemption_count) + meta_data = logger_utils.save_meta_data(workload, rng_seed, preemption_count) + logger_utils.write_json(meta_file_name, meta_data) flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json') logging.info(f'Saving flags to {flag_file_name}.') logger_utils.write_json(flag_file_name, 
flags.FLAGS.flag_values_dict()) From 557bf0df368906675e4368e3dd9937eeac47068e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Oct 2023 01:51:45 +0000 Subject: [PATCH 27/48] fix --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index c3d0557e1..7b9e8423b 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -273,7 +273,7 @@ def train_once( checkpoint_dir=log_dir) meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') logging.info(f'Saving meta data to {meta_file_name}.') - meta_data = logger_utils.save_meta_data(workload, rng_seed, preemption_count) + meta_data = logger_utils.get_meta_data(workload, rng_seed, preemption_count) logger_utils.write_json(meta_file_name, meta_data) flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json') logging.info(f'Saving flags to {flag_file_name}.') From 2598d3945884f422b5f0ecfce5bfa0c2f4b09ef5 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Oct 2023 01:56:07 +0000 Subject: [PATCH 28/48] fix --- algorithmic_efficiency/logger_utils.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index de5862b72..5373ed927 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -262,7 +262,7 @@ def _get_workload_properties(workload: spec.Workload) -> dict: return workload_properties -def get_meta_data(workload: spec.Workload) -> dict: +def get_meta_data(workload: spec.Workload, rng_seed: int = None) -> dict: meta_data = {} workload_properties = _get_workload_properties(workload) meta_data.update(workload_properties) @@ -272,14 +272,8 @@ def get_meta_data(workload: spec.Workload) -> dict: meta_data.update(system_software_info) system_hardware_info = _get_system_hardware_info() meta_data.update(system_hardware_info) - return meta_data - - -def get_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str): - meta_data = get_meta_data(workload) - meta_data.update({'rng_seed': rng_seed}) - # with open(meta_file_name, 'w') as f: - # f.write(json.dumps(meta_data, indent=2)) + if rng_seed: + meta_data.update({'rng_seed': rng_seed}) return meta_data From 9418f4f3dd23052c08810007aad1816896b39754 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 7 Oct 2023 01:59:18 +0000 Subject: [PATCH 29/48] fix --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 7b9e8423b..344394abe 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -273,7 +273,7 @@ def train_once( checkpoint_dir=log_dir) meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') logging.info(f'Saving meta data to {meta_file_name}.') - meta_data = logger_utils.get_meta_data(workload, rng_seed, preemption_count) + meta_data = logger_utils.get_meta_data(workload, rng_seed) logger_utils.write_json(meta_file_name, meta_data) flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json') logging.info(f'Saving flags to {flag_file_name}.') From 931337ddceb6ab1504cdded3ccba65baf6e98d0d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Oct 2023 17:51:13 +0000 Subject: [PATCH 30/48] remove logging --- reference_algorithms/target_setting_algorithms/jax_nadamw.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py 
b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index dc0e3eb09..21f2a7b2b 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -7,7 +7,6 @@ import jax import jax.numpy as jnp import optax -from absl import logging from algorithmic_efficiency import spec from reference_algorithms.target_setting_algorithms import cosine_warmup @@ -153,7 +152,6 @@ def init_optimizer_state(workload: spec.Workload, del rng target_setting_step_hint = int(0.75 * workload.step_hint) - logging.info(f'target setting step hint: {target_setting_step_hint}') lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, hyperparameters) From aeed475c980fd4855d2c92c96d84f80be3a483b4 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Oct 2023 17:52:51 +0000 Subject: [PATCH 31/48] revert checkpoint utils debugging --- algorithmic_efficiency/checkpoint_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/checkpoint_utils.py b/algorithmic_efficiency/checkpoint_utils.py index 3ebace5a4..fb7449b99 100644 --- a/algorithmic_efficiency/checkpoint_utils.py +++ b/algorithmic_efficiency/checkpoint_utils.py @@ -192,7 +192,7 @@ def save_checkpoint(framework: str, """ if framework == 'jax': model_params = jax.device_get(jax_utils.unreplicate(model_params)) - opt_state, _, _ = optimizer_state + opt_state, _ = optimizer_state opt_state = jax.device_get(jax_utils.unreplicate(opt_state)) model_state = jax.device_get(jax_utils.unreplicate(model_state)) else: From 7098843f7d31ae057e0c06b1fd0cbe60181d2e85 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Oct 2023 19:25:14 +0000 Subject: [PATCH 32/48] extend max_allowed_runtime_sec for conformer --- .../workloads/librispeech_conformer/workload.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py index dc7fb912b..b6321e246 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py @@ -67,7 +67,8 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 61_068 # ~17 hours + return 122136 # ~34h extended max_allowed_run_time for conformer OOM issue + @property def eval_period_time_sec(self) -> int: From e663f8dffb79713ef809cefe9bc308c7ca672b5f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 9 Oct 2023 22:40:22 +0000 Subject: [PATCH 33/48] update speech targets --- .../workloads/librispeech_conformer/workload.py | 4 ++-- .../librispeech_deepspeech/librispeech_jax/workload.py | 4 ++-- .../librispeech_deepspeech/librispeech_pytorch/workload.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py index dc7fb912b..2ad355975 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py @@ -19,14 +19,14 @@ def has_reached_validation_target(self, eval_result: Dict[str, @property def validation_target_value(self) -> float: - return 0.084952 + return 0.085884 def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: return eval_result['test/wer'] < self.test_target_value @property def 
test_target_value(self) -> float: - return 0.053000 + return 0.052981 @property def loss_type(self) -> spec.LossType: diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index cb07075f4..ac6005225 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -53,11 +53,11 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: @property def validation_target_value(self) -> float: - return 0.118232 + return 0.119936 @property def test_target_value(self) -> float: - return 0.073397 + return 0.074143 @property def step_hint(self) -> int: diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 09a4b0aa4..bcdd78fb5 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -62,11 +62,11 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: @property def validation_target_value(self) -> float: - return 0.118232 + return 0.119936 @property def test_target_value(self) -> float: - return 0.073397 + return 0.074143 @property def step_hint(self) -> int: From fe2f560fec57d772941082e007169b080071a0fb Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Tue, 10 Oct 2023 13:13:51 -0400 Subject: [PATCH 34/48] Initial Commit --- CONTRIBUTING.md | 2 +- .../development_algorithms/README.md | 5 - .../criteo1tb/criteo1tb_jax/__init__.py | 0 .../criteo1tb/criteo1tb_jax/submission.py | 142 ------------ .../criteo1tb/criteo1tb_pytorch/__init__.py | 0 .../criteo1tb/criteo1tb_pytorch/submission.py | 110 --------- .../criteo1tb/tuning_search_space.json | 27 --- .../fastmri/__init__.py | 0 .../fastmri/fastmri_jax/__init__.py | 0 .../fastmri/fastmri_jax/submission.py | 145 ------------ .../fastmri/fastmri_pytorch/__init__.py | 0 .../fastmri/fastmri_pytorch/submission.py | 104 --------- .../fastmri/tuning_search_space.json | 7 - .../imagenet_resnet/__init__.py | 0 .../imagenet_resnet/imagenet_jax/__init__.py | 0 .../imagenet_jax/submission.py | 153 ------------- .../imagenet_pytorch/__init__.py | 0 .../imagenet_pytorch/submission.py | 119 ---------- .../imagenet_resnet/tuning_search_space.json | 7 - .../imagenet_vit/__init__.py | 0 .../imagenet_vit/imagenet_jax/__init__.py | 0 .../imagenet_vit/imagenet_jax/submission.py | 153 ------------- .../imagenet_vit/imagenet_pytorch/__init__.py | 0 .../imagenet_pytorch/submission.py | 117 ---------- .../imagenet_vit/tuning_search_space.json | 9 - .../librispeech_conformer/__init__.py | 0 .../librispeech_jax/__init__.py | 0 .../librispeech_jax/submission.py | 211 ------------------ .../librispeech_pytorch/__init__.py | 0 .../librispeech_pytorch/submission.py | 119 ---------- .../tuning_search_space.json | 11 - .../librispeech_deepspeech/__init__.py | 0 .../librispeech_jax/__init__.py | 0 .../librispeech_jax/submission.py | 206 ----------------- .../librispeech_pytorch/__init__.py | 0 .../librispeech_pytorch/submission.py | 116 ---------- .../tuning_search_space.json | 10 - .../development_algorithms/ogbg/__init__.py | 0 .../ogbg/ogbg_jax/__init__.py | 0 .../ogbg/ogbg_jax/submission.py | 122 ---------- 
.../ogbg/ogbg_pytorch/__init__.py | 0 .../ogbg/ogbg_pytorch/submission.py | 99 -------- .../ogbg/tuning_search_space.json | 1 - .../development_algorithms/wmt/__init__.py | 0 .../wmt/tuning_search_space.json | 8 - .../wmt/wmt_jax/__init__.py | 0 .../wmt/wmt_jax/submission.py | 194 ---------------- .../wmt/wmt_pytorch/__init__.py | 0 .../wmt/wmt_pytorch/submission.py | 167 -------------- tests/reference_algorithm_tests.py | 16 +- tests/submission_runner_test.py | 2 +- 51 files changed, 9 insertions(+), 2373 deletions(-) delete mode 100644 reference_algorithms/development_algorithms/README.md delete mode 100644 reference_algorithms/development_algorithms/criteo1tb/criteo1tb_jax/__init__.py delete mode 100644 reference_algorithms/development_algorithms/criteo1tb/criteo1tb_jax/submission.py delete mode 100644 reference_algorithms/development_algorithms/criteo1tb/criteo1tb_pytorch/__init__.py delete mode 100644 reference_algorithms/development_algorithms/criteo1tb/criteo1tb_pytorch/submission.py delete mode 100644 reference_algorithms/development_algorithms/criteo1tb/tuning_search_space.json delete mode 100644 reference_algorithms/development_algorithms/fastmri/__init__.py delete mode 100644 reference_algorithms/development_algorithms/fastmri/fastmri_jax/__init__.py delete mode 100644 reference_algorithms/development_algorithms/fastmri/fastmri_jax/submission.py delete mode 100644 reference_algorithms/development_algorithms/fastmri/fastmri_pytorch/__init__.py delete mode 100644 reference_algorithms/development_algorithms/fastmri/fastmri_pytorch/submission.py delete mode 100644 reference_algorithms/development_algorithms/fastmri/tuning_search_space.json delete mode 100644 reference_algorithms/development_algorithms/imagenet_resnet/__init__.py delete mode 100644 reference_algorithms/development_algorithms/imagenet_resnet/imagenet_jax/__init__.py delete mode 100644 reference_algorithms/development_algorithms/imagenet_resnet/imagenet_jax/submission.py delete mode 100644 reference_algorithms/development_algorithms/imagenet_resnet/imagenet_pytorch/__init__.py delete mode 100644 reference_algorithms/development_algorithms/imagenet_resnet/imagenet_pytorch/submission.py delete mode 100644 reference_algorithms/development_algorithms/imagenet_resnet/tuning_search_space.json delete mode 100644 reference_algorithms/development_algorithms/imagenet_vit/__init__.py delete mode 100644 reference_algorithms/development_algorithms/imagenet_vit/imagenet_jax/__init__.py delete mode 100644 reference_algorithms/development_algorithms/imagenet_vit/imagenet_jax/submission.py delete mode 100644 reference_algorithms/development_algorithms/imagenet_vit/imagenet_pytorch/__init__.py delete mode 100644 reference_algorithms/development_algorithms/imagenet_vit/imagenet_pytorch/submission.py delete mode 100644 reference_algorithms/development_algorithms/imagenet_vit/tuning_search_space.json delete mode 100644 reference_algorithms/development_algorithms/librispeech_conformer/__init__.py delete mode 100644 reference_algorithms/development_algorithms/librispeech_conformer/librispeech_jax/__init__.py delete mode 100644 reference_algorithms/development_algorithms/librispeech_conformer/librispeech_jax/submission.py delete mode 100644 reference_algorithms/development_algorithms/librispeech_conformer/librispeech_pytorch/__init__.py delete mode 100644 reference_algorithms/development_algorithms/librispeech_conformer/librispeech_pytorch/submission.py delete mode 100644 
reference_algorithms/development_algorithms/librispeech_conformer/tuning_search_space.json delete mode 100644 reference_algorithms/development_algorithms/librispeech_deepspeech/__init__.py delete mode 100644 reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_jax/__init__.py delete mode 100644 reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_jax/submission.py delete mode 100644 reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_pytorch/__init__.py delete mode 100644 reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_pytorch/submission.py delete mode 100644 reference_algorithms/development_algorithms/librispeech_deepspeech/tuning_search_space.json delete mode 100644 reference_algorithms/development_algorithms/ogbg/__init__.py delete mode 100644 reference_algorithms/development_algorithms/ogbg/ogbg_jax/__init__.py delete mode 100644 reference_algorithms/development_algorithms/ogbg/ogbg_jax/submission.py delete mode 100644 reference_algorithms/development_algorithms/ogbg/ogbg_pytorch/__init__.py delete mode 100644 reference_algorithms/development_algorithms/ogbg/ogbg_pytorch/submission.py delete mode 100644 reference_algorithms/development_algorithms/ogbg/tuning_search_space.json delete mode 100644 reference_algorithms/development_algorithms/wmt/__init__.py delete mode 100644 reference_algorithms/development_algorithms/wmt/tuning_search_space.json delete mode 100644 reference_algorithms/development_algorithms/wmt/wmt_jax/__init__.py delete mode 100644 reference_algorithms/development_algorithms/wmt/wmt_jax/submission.py delete mode 100644 reference_algorithms/development_algorithms/wmt/wmt_pytorch/__init__.py delete mode 100644 reference_algorithms/development_algorithms/wmt/wmt_pytorch/submission.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 33a14f83c..38867b369 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -216,7 +216,7 @@ pylint tests ## Unit and integration tests We run unit tests and integration tests as part of the of github actions as well. -You can also use `python tests/reference_algorithm_tests.py` to run a single model update and two model evals for each workload using the reference algorithm in `reference_algorithms/development_algorithms/`. +You can also use `python tests/reference_algorithm_tests.py` to run a single model update and two model evals for each workload using the reference algorithm in `reference_algorithms/target_setting_algorithms/`. ## Regression tests We also have regression tests available in [.github/workflows/regression_tests.yml](https://github.com/mlcommons/algorithmic-efficiency/tree/main/.github/workflows/regression_tests.yml) that can be run semi-automatically. diff --git a/reference_algorithms/development_algorithms/README.md b/reference_algorithms/development_algorithms/README.md deleted file mode 100644 index 12b1b1f8e..000000000 --- a/reference_algorithms/development_algorithms/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Development Algorithms - -These are various algorithms used during the testing and development of the codebase. - -These are not valid submissions, because they use a different hyperparameter settings and algorithms per workload. 
diff --git a/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_jax/__init__.py b/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_jax/submission.py b/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_jax/submission.py deleted file mode 100644 index 4dea0c321..000000000 --- a/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_jax/submission.py +++ /dev/null @@ -1,142 +0,0 @@ -"""Training algorithm track submission functions for Criteo1TB DLRM-Small.""" - -import functools -from typing import Dict, Iterator, List, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 524_288 // 2 - - -def create_learning_rate_fn(workload: spec.Workload, - hparams: spec.Hyperparameters): - """Create learning rate schedule.""" - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hparams.learning_rate, - transition_steps=hparams.warmup_steps) - cosine_fn = optax.cosine_decay_schedule( - init_value=hparams.learning_rate, - decay_steps=(workload.step_hint - hparams.warmup_steps)) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[hparams.warmup_steps]) - return schedule_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_params - del model_state - del rng - learning_rate_fn = create_learning_rate_fn(workload, hyperparameters) - opt_init_fn, opt_update_fn = optax.adamw( - learning_rate=learning_rate_fn, - b1=hyperparameters.beta1, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng): - - def _loss_fn(params): - """loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=False) - loss_dict = workload.loss_fn(batch['targets'], logits) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - return loss, new_model_state - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (loss, new_model_state), grad = grad_fn(current_param_container) - (loss, grad) = lax.pmean((loss, grad), axis_name='batch') - grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_model_state, new_optimizer_state, updated_params, loss, grad_norm - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - 
hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - # del global_step - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - new_model_state, new_optimizer_state, new_params, loss, grad_norm = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, batch, per_device_rngs) - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_pytorch/__init__.py b/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_pytorch/submission.py b/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_pytorch/submission.py deleted file mode 100644 index d9d9c29b5..000000000 --- a/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_pytorch/submission.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import Dict, Iterator, List, Tuple - -import torch -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - batch_sizes = {'criteo1tb': 524_288} - return batch_sizes[workload_name] - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del rng - del model_state - - base_lr = hyperparameters.learning_rate - - optimizer_state = { - 'optimizer': - torch.optim.AdamW( - model_params.parameters(), - lr=base_lr, - weight_decay=hyperparameters.weight_decay, - betas=(hyperparameters.beta1, 0.999)), - } - - scheduler1 = LinearLR( - optimizer_state['optimizer'], - start_factor=1e-12, - end_factor=1., - total_iters=hyperparameters.warmup_steps) - - scheduler2 = CosineAnnealingLR( - optimizer_state['optimizer'], - T_max=(workload.step_hint - hyperparameters.warmup_steps), - ) - - optimizer_state['scheduler'] = SequentialLR( - optimizer_state['optimizer'], - schedulers=[scheduler1, scheduler2], - milestones=[hyperparameters.warmup_steps]) - - return 
optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - - current_model = current_param_container - current_param_container.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - dropout_rate=None, - aux_dropout_rate=None, - update_batch_norm=False) - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], logits_batch=logits_batch) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - - loss.backward() - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - return (optimizer_state, current_param_container, new_model_state) - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/criteo1tb/tuning_search_space.json b/reference_algorithms/development_algorithms/criteo1tb/tuning_search_space.json deleted file mode 100644 index a30292bdb..000000000 --- a/reference_algorithms/development_algorithms/criteo1tb/tuning_search_space.json +++ /dev/null @@ -1,27 +0,0 @@ -{ - "learning_rate": { - "feasible_points": [ - 0.0065686501947063445 - ] - }, - "beta1": { - "feasible_points": [ - 0.8743797750166902 - ] - }, - "beta2": { - "feasible_points": [ - 0.9980006182116233 - ] - }, - "warmup_steps": { - "feasible_points": [ - 800 - ] - }, - "weight_decay": { - "feasible_points": [ - 1.5301171352729387e-5 - ] - } -} diff --git a/reference_algorithms/development_algorithms/fastmri/__init__.py b/reference_algorithms/development_algorithms/fastmri/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/fastmri/fastmri_jax/__init__.py b/reference_algorithms/development_algorithms/fastmri/fastmri_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/fastmri/fastmri_jax/submission.py b/reference_algorithms/development_algorithms/fastmri/fastmri_jax/submission.py deleted file mode 100644 index 73b020112..000000000 --- a/reference_algorithms/development_algorithms/fastmri/fastmri_jax/submission.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Training algorithm track submission functions for FastMRI in Jax.""" - -import functools -from typing import Dict, Iterator, List, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as 
jnp -import optax - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 64 - - -def create_learning_rate_fn(hparams: spec.Hyperparameters, - steps_per_epoch: int): - """Create learning rate schedule.""" - max_num_train_steps = 500 * steps_per_epoch - decay_epoch_period = hparams.lr_step_size * steps_per_epoch - decay_events = range(decay_epoch_period, - max_num_train_steps, - decay_epoch_period) - schedule_fn = optax.piecewise_constant_schedule( - init_value=hparams.learning_rate, - boundaries_and_scales={t: hparams.lr_gamma for t in decay_events}) - return schedule_fn - - -def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): - steps_per_epoch = num_train_examples // get_batch_size('imagenet_resnet') - learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch) - opt_init_fn, opt_update_fn = optax.rmsprop( - learning_rate=learning_rate_fn, - decay=0.99) - return opt_init_fn, opt_update_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_params - del model_state - del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters, - workload.num_train_examples) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, None, 0, 0), - static_broadcasted_argnums=(0, 1, 4)) -def pmapped_train_step(workload, - opt_update_fn, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng): - - def _loss_fn(params): - """loss function used for training.""" - logits, _ = workload.model_fn( - params, - batch, - model_state=None, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - loss_dict = workload.loss_fn(batch['targets'], logits) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - weight_penalty_params = jax.tree_util.tree_leaves(params) - weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1) - weight_penalty = hyperparameters.l2 * 0.5 * weight_l2 - loss = loss + weight_penalty - return loss - - grad_fn = jax.grad(_loss_fn) - grad = grad_fn(current_param_container) - grad = lax.pmean(grad, axis_name='batch') - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - - return new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del model_state - del loss_type - del eval_results - del global_step - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - new_optimizer_state, 
new_params = pmapped_train_step( - workload, opt_update_fn, optimizer_state, - current_param_container, hyperparameters, batch, per_device_rngs) - return (new_optimizer_state, opt_update_fn), new_params, None - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/fastmri/fastmri_pytorch/__init__.py b/reference_algorithms/development_algorithms/fastmri/fastmri_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/fastmri/fastmri_pytorch/submission.py b/reference_algorithms/development_algorithms/fastmri/fastmri_pytorch/submission.py deleted file mode 100644 index 38828d4c3..000000000 --- a/reference_algorithms/development_algorithms/fastmri/fastmri_pytorch/submission.py +++ /dev/null @@ -1,104 +0,0 @@ -"""Training algorithm track submission functions for FastMRI.""" - -from typing import Dict, Iterator, List, Tuple - -import torch -from torch.optim.lr_scheduler import StepLR - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - batch_sizes = {'fastmri': 8} - return batch_sizes[workload_name] - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del workload - del model_state - del rng - - base_lr = hyperparameters.learning_rate * get_batch_size('fastmri') - optimizer_state = { - 'optimizer': - torch.optim.RMSprop( - model_params.parameters(), - lr=base_lr, - weight_decay=hyperparameters.l2), - } - - optimizer_state['scheduler'] = StepLR( - optimizer_state['optimizer'], - step_size=hyperparameters.lr_step_size, - gamma=hyperparameters.lr_gamma) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del loss_type - del eval_results - - current_model = current_param_container - current_param_container.train() - optimizer_state['optimizer'].zero_grad() - - outputs_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], logits_batch=outputs_batch) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - - loss.backward() - 
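# Note on the scheduler call a few lines below: the StepLR is advanced only
# when (global_step + 1) % steps_per_epoch == 0, i.e. once per epoch, so the
# learning rate decays every lr_step_size epochs, mirroring the epoch-based
# piecewise-constant schedule used by the JAX FastMRI submission above.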
optimizer_state['optimizer'].step() - steps_per_epoch = workload.num_train_examples // get_batch_size('fastmri') - if (global_step + 1) % steps_per_epoch == 0: - optimizer_state['scheduler'].step() - - return (optimizer_state, current_param_container, new_model_state) - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/fastmri/tuning_search_space.json b/reference_algorithms/development_algorithms/fastmri/tuning_search_space.json deleted file mode 100644 index 01e4e00c2..000000000 --- a/reference_algorithms/development_algorithms/fastmri/tuning_search_space.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "learning_rate": {"feasible_points": [0.001]}, - "num_epochs": {"feasible_points": [50]}, - "l2": {"feasible_points": [0.0]}, - "lr_step_size": {"feasible_points": [40]}, - "lr_gamma": {"feasible_points": [0.1]} -} \ No newline at end of file diff --git a/reference_algorithms/development_algorithms/imagenet_resnet/__init__.py b/reference_algorithms/development_algorithms/imagenet_resnet/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_jax/__init__.py b/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_jax/submission.py b/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_jax/submission.py deleted file mode 100644 index 9c686d524..000000000 --- a/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_jax/submission.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Training algorithm track submission functions for ImageNet.""" - -import functools -from typing import Dict, Iterator, List, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 1024 - - -def create_learning_rate_fn(hparams: spec.Hyperparameters, - steps_per_epoch: int): - """Create learning rate schedule.""" - base_learning_rate = hparams.learning_rate * \ - get_batch_size('imagenet_resnet') / 256. 
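# Worked example (reference values, for illustration only): with
# learning_rate = 0.1 from the accompanying tuning search space and the
# global batch size of 1024 returned by get_batch_size('imagenet_resnet'),
# the scaled peak rate is 0.1 * 1024 / 256 = 0.4, reached at the end of the
# linear warmup below before the cosine decay begins.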
- warmup_fn = optax.linear_schedule( - init_value=0., - end_value=base_learning_rate, - transition_steps=hparams.warmup_epochs * steps_per_epoch) - cosine_epochs = max(hparams.num_epochs - hparams.warmup_epochs, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=base_learning_rate, - decay_steps=cosine_epochs * steps_per_epoch) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], - boundaries=[hparams.warmup_epochs * steps_per_epoch]) - return schedule_fn - - -def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): - steps_per_epoch = num_train_examples // get_batch_size('imagenet_resnet') - learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch) - opt_init_fn, opt_update_fn = optax.sgd( - nesterov=True, - momentum=hyperparameters.momentum, - learning_rate=learning_rate_fn) - return opt_init_fn, opt_update_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_params - del model_state - del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters, - workload.num_train_examples) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng): - - def _loss_fn(params): - """loss function used for training.""" - variables = {'params': params, **model_state} - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss_dict = workload.loss_fn(batch['targets'], logits) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - weight_penalty_params = jax.tree_util.tree_leaves(variables['params']) - weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1) - weight_penalty = hyperparameters.l2 * 0.5 * weight_l2 - loss = loss + weight_penalty - return loss, (new_model_state, logits) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - aux, grad = grad_fn(current_param_container) - grad = lax.pmean(grad, axis_name='batch') - new_model_state, _ = aux[1] - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - - return new_model_state, new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del global_step - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - 
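# One rng per local device: pmapped_train_step is pmapped with in_axes 0 for
# the rng argument, so each replica runs the forward pass with its own key.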
new_model_state, new_optimizer_state, new_params = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, hyperparameters, batch, per_device_rngs) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_pytorch/__init__.py b/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_pytorch/submission.py b/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_pytorch/submission.py deleted file mode 100644 index 694e924f7..000000000 --- a/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_pytorch/submission.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Training algorithm track submission functions for ImageNet.""" -from typing import Dict, Iterator, List, Tuple - -import torch -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 1024 - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_state - del rng - - batch_size = get_batch_size('imagenet_resnet') - base_lr = hyperparameters.learning_rate * batch_size / 256. 
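# As in the JAX submission, the base rate follows the linear-scaling rule
# (learning_rate * batch_size / 256); the LinearLR -> CosineAnnealingLR chain
# built below is stepped once per batch in update_params, so warmup and decay
# lengths are expressed in optimizer steps rather than epochs.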
- optimizer_state = { - 'optimizer': - torch.optim.SGD( - model_params.parameters(), - lr=base_lr, - momentum=hyperparameters.momentum, - weight_decay=hyperparameters.l2, - nesterov=True), - } - - steps_per_epoch = workload.num_train_examples // batch_size - scheduler1 = LinearLR( - optimizer_state['optimizer'], - start_factor=1e-10, - end_factor=1., - total_iters=hyperparameters.warmup_epochs * steps_per_epoch) - cosine_epochs = max( - hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1) - scheduler2 = CosineAnnealingLR( - optimizer_state['optimizer'], T_max=cosine_epochs * steps_per_epoch) - - optimizer_state['scheduler'] = SequentialLR( - optimizer_state['optimizer'], - schedulers=[scheduler1, scheduler2], - milestones=[hyperparameters.warmup_epochs * steps_per_epoch]) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del hyperparameters - del loss_type - del eval_results - del global_step - - current_model = current_param_container - current_param_container.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], logits_batch=logits_batch) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - - loss.backward() - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - return (optimizer_state, current_param_container, new_model_state) - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/imagenet_resnet/tuning_search_space.json b/reference_algorithms/development_algorithms/imagenet_resnet/tuning_search_space.json deleted file mode 100644 index da969416b..000000000 --- a/reference_algorithms/development_algorithms/imagenet_resnet/tuning_search_space.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "learning_rate": {"feasible_points": [0.1]}, - "warmup_epochs": {"feasible_points": [5]}, - "num_epochs": {"feasible_points": [100]}, - "l2": {"feasible_points": [1e-4]}, - "momentum": {"feasible_points": [0.9]} -} \ No newline at end of file diff --git a/reference_algorithms/development_algorithms/imagenet_vit/__init__.py b/reference_algorithms/development_algorithms/imagenet_vit/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/imagenet_vit/imagenet_jax/__init__.py b/reference_algorithms/development_algorithms/imagenet_vit/imagenet_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/imagenet_vit/imagenet_jax/submission.py b/reference_algorithms/development_algorithms/imagenet_vit/imagenet_jax/submission.py deleted file mode 100644 index 4d65d9675..000000000 --- a/reference_algorithms/development_algorithms/imagenet_vit/imagenet_jax/submission.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Training algorithm track submission functions for ImageNet.""" - -import functools -from typing import Dict, Iterator, List, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 1024 - - -def create_learning_rate_fn(hparams: spec.Hyperparameters, - steps_per_epoch: int): - """Create learning rate schedule.""" - base_learning_rate = hparams.learning_rate * \ - get_batch_size('imagenet_vit') / 1024. 
- warmup_fn = optax.linear_schedule( - init_value=0., - end_value=base_learning_rate, - transition_steps=hparams.warmup_epochs * steps_per_epoch) - cosine_epochs = max(hparams.num_epochs - hparams.warmup_epochs, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=base_learning_rate, - decay_steps=cosine_epochs * steps_per_epoch) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], - boundaries=[hparams.warmup_epochs * steps_per_epoch]) - return schedule_fn - - -def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): - steps_per_epoch = num_train_examples // get_batch_size('imagenet_vit') - learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch) - opt_init_fn, opt_update_fn = optax.adam( - b1=hyperparameters.beta1, - b2=hyperparameters.beta2, - eps=hyperparameters.epsilon, - learning_rate=learning_rate_fn) - return opt_init_fn, opt_update_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_params - del model_state - del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters, - workload.num_train_examples) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng): - - def _loss_fn(params): - """loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss_dict = workload.loss_fn(batch['targets'], logits) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - weight_penalty_params = jax.tree_util.tree_leaves(params) - weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1) - weight_penalty = hyperparameters.l2 * 0.5 * weight_l2 - loss = loss + weight_penalty - return loss, (new_model_state, logits) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - aux, grad = grad_fn(current_param_container) - grad = lax.pmean(grad, axis_name='batch') - new_model_state, _ = aux[1] - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - - return new_model_state, new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del global_step - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - new_model_state, new_optimizer_state, 
new_params = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, hyperparameters, batch, per_device_rngs) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/imagenet_vit/imagenet_pytorch/__init__.py b/reference_algorithms/development_algorithms/imagenet_vit/imagenet_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/imagenet_vit/imagenet_pytorch/submission.py b/reference_algorithms/development_algorithms/imagenet_vit/imagenet_pytorch/submission.py deleted file mode 100644 index eee2a01db..000000000 --- a/reference_algorithms/development_algorithms/imagenet_vit/imagenet_pytorch/submission.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Training algorithm track submission functions for ImageNet.""" -from typing import Dict, Iterator, List, Tuple - -import torch -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 1024 - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_state - del rng - - batch_size = get_batch_size('imagenet_vit') - base_lr = hyperparameters.learning_rate * batch_size / 1024. 
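# Same reference batch size of 1024 as the JAX ViT submission, so base_lr
# equals hyperparameters.learning_rate here as well; Adam's betas and epsilon
# are taken directly from the search space.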
- optimizer_state = { - 'optimizer': - torch.optim.Adam( - model_params.parameters(), - lr=base_lr, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=hyperparameters.epsilon), - } - - steps_per_epoch = workload.num_train_examples // batch_size - scheduler1 = LinearLR( - optimizer_state['optimizer'], - start_factor=1e-10, - end_factor=1., - total_iters=hyperparameters.warmup_epochs * steps_per_epoch) - cosine_epochs = max( - hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1) - scheduler2 = CosineAnnealingLR( - optimizer_state['optimizer'], T_max=cosine_epochs * steps_per_epoch) - - optimizer_state['scheduler'] = SequentialLR( - optimizer_state['optimizer'], - schedulers=[scheduler1, scheduler2], - milestones=[hyperparameters.warmup_epochs * steps_per_epoch]) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del loss_type - del eval_results - del global_step - - current_model = current_param_container - current_param_container.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], logits_batch=logits_batch) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - - loss.backward() - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - return (optimizer_state, current_param_container, new_model_state) - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/imagenet_vit/tuning_search_space.json b/reference_algorithms/development_algorithms/imagenet_vit/tuning_search_space.json deleted file mode 100644 index e6cf84733..000000000 --- a/reference_algorithms/development_algorithms/imagenet_vit/tuning_search_space.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "learning_rate": {"feasible_points": [1e-3]}, - "beta1": {"feasible_points": [0.9]}, - "beta2": {"feasible_points": [0.999]}, - "epsilon": {"feasible_points": [1e-8]}, - "num_epochs": {"feasible_points": [100]}, - "warmup_epochs": {"feasible_points": [5]}, - "l2": {"feasible_points": [1e-1]} -} \ No newline at end of file diff --git a/reference_algorithms/development_algorithms/librispeech_conformer/__init__.py b/reference_algorithms/development_algorithms/librispeech_conformer/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_jax/__init__.py b/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_jax/submission.py b/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_jax/submission.py deleted file mode 100644 index ea314b820..000000000 --- a/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_jax/submission.py +++ /dev/null @@ -1,211 +0,0 @@ -"""Training algorithm track submission functions for LibriSpeech.""" -import functools -from typing import Dict, Iterator, List, Tuple - -from absl import logging -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import numpy as np -import optax - -from algorithmic_efficiency import spec - -_GRAD_CLIP_EPS = 1e-6 - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 256 - - -def get_learning_rate(step, hyperparams): - warmup_steps = hyperparams.warmup_steps - if step < warmup_steps: - current_lr = (step * hyperparams.base_lr) / warmup_steps - else: - decay_factor = (1 + np.cos(step / hyperparams.training_steps * np.pi)) * 0.5 - current_lr = hyperparams.base_lr * decay_factor - return current_lr - - -def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): - opt_init_fn, opt_update_fn = optax.inject_hyperparams(optax.adamw)( - b1=hyperparameters.beta1, - b2=hyperparameters.beta2, - eps=hyperparameters.epsilon, - weight_decay=hyperparameters.weight_decay, - learning_rate=0.0) - return opt_init_fn, opt_update_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_state - del rng - - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters, - workload.num_train_examples) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -def l2_regularization(params, l2_decay_rank_threshold): - """Computes the squared l2 norm of the given parameters. 
- - This function will only filter for parameters with - rank >= l2_decay_rank_threshold. So if this threshold is set to 2, then all - 1d (and lower) parameter arrays, including all bias and batch norm params, - will be ignored in this computation. - - - Args: - params: Pytree containing parameters. - l2_decay_rank_threshold: The calculation will only include parameters with - param.ndim >= l2_decay_rank_threshold. Set to 2 to ignore all bias and - batch_norm params in the model. - - Returns: - weight_l2: the squared l2 norm of all params matching the threshold. - """ - weight_penalty_params = jax.tree_util.tree_leaves(params) - weight_l2 = sum( - jnp.sum(x**2) - for x in weight_penalty_params - if x.ndim >= l2_decay_rank_threshold) - return weight_l2 - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0, None), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng, - lr): - optimizer_state.hyperparams['learning_rate'] = lr - - def _loss_fn(params): - """loss function used for training.""" - (logits, logit_paddings), new_model_state = workload.model_fn( - params, - batch, - model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - loss_dict = workload.loss_fn(batch['targets'], (logits, logit_paddings)) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - grad_clip = hyperparameters.grad_clip - grad_norm = jnp.sqrt(l2_regularization(grad, 0)) - scaled_grad = jax.tree_map( - lambda x: x / (grad_norm + _GRAD_CLIP_EPS) * grad_clip, grad) - grad = jax.lax.cond(grad_norm > grad_clip, - lambda _: scaled_grad, - lambda _: grad, - None) - - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - - return new_model_state, new_optimizer_state, updated_params, loss, grad_norm - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del eval_results - del loss_type - - lr = get_learning_rate(global_step, hyperparameters) - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - per_device_rngs, - lr) - new_model_state, new_optimizer_state, new_params, loss, grad_norm = outputs - - if global_step <= 1000 or global_step % 100 == 
0: - logging.info('%d) loss = %0.3f, grad_norm = %0.3f lr = %0.6f', - global_step, - loss.mean(), - grad_norm.mean(), - lr) - - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'train_step_ctc_loss': loss.mean(), - 'grad_norm': grad_norm.mean(), - 'learning_rate': lr, - }, - global_step) - - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_pytorch/__init__.py b/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_pytorch/submission.py b/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_pytorch/submission.py deleted file mode 100644 index ce38d7509..000000000 --- a/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_pytorch/submission.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Training algorithm track submission functions for LibriSpeech.""" -from typing import Dict, Iterator, List, Tuple - -import numpy as np -import torch -import torch.distributed.nn as dist_nn - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - -device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu") -ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none') - - -def get_batch_size(workload_name): - # Return the global batch size. 
- del workload_name - return 256 - - -def get_learning_rate(step, hyperparams): - warmup_steps = hyperparams.warmup_steps - if step < warmup_steps: - current_lr = (step * hyperparams.base_lr) / warmup_steps - else: - decay_factor = (1 + np.cos(step / hyperparams.training_steps * np.pi)) * 0.5 - current_lr = hyperparams.base_lr * decay_factor - return current_lr - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del workload - del model_state - del rng - optimizer = torch.optim.AdamW( - params=model_params.parameters(), - lr=0.0, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=hyperparameters.epsilon, - weight_decay=hyperparameters.weight_decay) - return {'optimizer': optimizer} - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del eval_results - del model_state - del loss_type - optimizer = optimizer_state['optimizer'] - optimizer.zero_grad() - current_model = current_param_container - - (logits, logits_padding), _ = workload.model_fn( - current_model, - batch, - None, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn(batch['targets'], (logits, logits_padding)) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. - summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - - for g in optimizer.param_groups: - g['lr'] = get_learning_rate(global_step, hyperparameters) - if hasattr(hyperparameters, 'grad_clip'): - torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=hyperparameters.grad_clip) - optimizer.step() - return optimizer_state, current_param_container, None - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/librispeech_conformer/tuning_search_space.json b/reference_algorithms/development_algorithms/librispeech_conformer/tuning_search_space.json deleted file mode 100644 index 821288415..000000000 --- a/reference_algorithms/development_algorithms/librispeech_conformer/tuning_search_space.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "base_lr": {"feasible_points": [0.001997]}, - "beta1": {"feasible_points": [0.7132]}, - "beta2": {"feasible_points": [0.9982]}, - "epsilon": {"feasible_points": [1e-9]}, - "weight_decay": {"feasible_points":[0.026595]}, - "grad_clip": {"feasible_points": [5.0]}, - "warmup_steps" : {"feasible_points": [10000]}, - "training_steps" : {"feasible_points": [100000]} -} - diff --git a/reference_algorithms/development_algorithms/librispeech_deepspeech/__init__.py b/reference_algorithms/development_algorithms/librispeech_deepspeech/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_jax/__init__.py b/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_jax/submission.py b/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_jax/submission.py deleted file mode 100644 index f8a368f3f..000000000 --- a/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_jax/submission.py +++ /dev/null @@ -1,206 +0,0 @@ -"""Training algorithm track submission functions for LibriSpeech.""" -import functools -from typing import Dict, Iterator, List, Tuple - -from absl import logging -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import numpy as np -import optax - -from algorithmic_efficiency import spec - -_GRAD_CLIP_EPS = 1e-6 - - -def get_batch_size(workload_name): - # Return the global batch size. 
- del workload_name - return 256 - - -def get_learning_rate(step, hyperparams): - warmup_steps = hyperparams.warmup_steps - if step < warmup_steps: - current_lr = (step * hyperparams.base_lr) / warmup_steps - else: - decay_factor = (1 + np.cos(step / hyperparams.training_steps * np.pi)) * 0.5 - current_lr = hyperparams.base_lr * decay_factor - return current_lr - - -def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): - opt_init_fn, opt_update_fn = optax.inject_hyperparams(optax.adamw)( - b1=hyperparameters.beta1, - b2=hyperparameters.beta2, - eps=hyperparameters.epsilon, - weight_decay=hyperparameters.weight_decay, - learning_rate=0.0) - return opt_init_fn, opt_update_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_state - del rng - - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters, - workload.num_train_examples) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -def l2_regularization(params, l2_decay_rank_threshold): - """Computes the squared l2 norm of the given parameters. - - This function will only filter for parameters with - rank >= l2_decay_rank_threshold. So if this threshold is set to 2, then all - 1d (and lower) parameter arrays, including all bias and batch norm params, - will be ignored in this computation. - - - Args: - params: Pytree containing parameters. - l2_decay_rank_threshold: The calculation will only include parameters with - param.ndim >= l2_decay_rank_threshold. Set to 2 to ignore all bias and - batch_norm params in the model. - - Returns: - weight_l2: the squared l2 norm of all params matching the threshold. - """ - weight_penalty_params = jax.tree_util.tree_leaves(params) - weight_l2 = sum( - jnp.sum(x**2) - for x in weight_penalty_params - if x.ndim >= l2_decay_rank_threshold) - return weight_l2 - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0, None), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng, - lr): - optimizer_state.hyperparams['learning_rate'] = lr - - def _loss_fn(params): - """loss function used for training.""" - (logits, logit_paddings), new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn(batch['targets'], (logits, logit_paddings)) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean loss and grad. 
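# lax.psum over the 'batch' (device) axis sums the loss, the valid-example
# count, and the gradients across all replicas; dividing by the global
# n_valid_examples afterwards gives the true per-example mean before the
# gradient-norm clipping below.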
- (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - grad_norm = jnp.sqrt(l2_regularization(grad, 0)) - grad_clip = hyperparameters.grad_clip - grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) - grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) - - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - - return new_model_state, new_optimizer_state, updated_params, loss, grad_norm - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del eval_results - del loss_type - - lr = get_learning_rate(global_step, hyperparameters) - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - per_device_rngs, - lr) - new_model_state, new_optimizer_state, new_params, loss, grad_norm = outputs - - if global_step <= 1000 or global_step % 100 == 0: - logging.info('%d) loss = %0.3f, grad_norm = %0.3f lr = %0.6f', - global_step, - loss.mean(), - grad_norm.mean(), - lr) - if workload.summary_writer is not None: - workload.summary_writer.scalar('train_step_ctc_loss', - loss.mean(), - global_step) - workload.summary_writer.scalar('grad_norm', grad_norm.mean(), global_step) - workload.summary_writer.scalar('learning_rate', lr, global_step) - - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. 
- """ - del optimizer_state - del current_param_container - del global_step - del rng - del hyperparameters - del workload - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_pytorch/__init__.py b/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_pytorch/submission.py b/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_pytorch/submission.py deleted file mode 100644 index 9170086a5..000000000 --- a/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_pytorch/submission.py +++ /dev/null @@ -1,116 +0,0 @@ -"""Training algorithm track submission functions for LibriSpeech.""" -from typing import Dict, Iterator, List, Tuple - -import numpy as np -import torch -import torch.distributed.nn as dist_nn - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 256 - - -def get_learning_rate(step, hyperparams): - warmup_steps = hyperparams.warmup_steps - if step < warmup_steps: - current_lr = (step * hyperparams.base_lr) / warmup_steps - else: - decay_factor = (1 + np.cos(step / hyperparams.training_steps * np.pi)) * 0.5 - current_lr = hyperparams.base_lr * decay_factor - return current_lr - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del workload - del model_state - del rng - optimizer = torch.optim.AdamW( - params=model_params.parameters(), - lr=0.0, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=hyperparameters.epsilon, - weight_decay=hyperparameters.weight_decay) - return {'optimizer': optimizer} - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del eval_results - del model_state - del loss_type - optimizer = optimizer_state['optimizer'] - optimizer.zero_grad() - current_model = current_param_container - - (logits, logits_padding), _ = workload.model_fn( - current_model, - batch, - None, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn(batch['targets'], (logits, logits_padding)) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. 
- summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - - for g in optimizer.param_groups: - g['lr'] = get_learning_rate(global_step, hyperparameters) - if hasattr(hyperparameters, 'grad_clip'): - torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=hyperparameters.grad_clip) - optimizer.step() - return optimizer_state, current_param_container, None - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/librispeech_deepspeech/tuning_search_space.json b/reference_algorithms/development_algorithms/librispeech_deepspeech/tuning_search_space.json deleted file mode 100644 index d337200c7..000000000 --- a/reference_algorithms/development_algorithms/librispeech_deepspeech/tuning_search_space.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "base_lr": {"feasible_points": [0.002632520052132928]}, - "beta1": {"feasible_points": [0.9945481149103774]}, - "beta2": {"feasible_points": [0.996379002889742]}, - "epsilon": {"feasible_points": [1e-8]}, - "weight_decay": {"feasible_points":[0.107175616660346]}, - "grad_clip": {"feasible_points": [5.0]}, - "warmup_steps" : {"feasible_points": [3000]}, - "training_steps" : {"feasible_points": [60000]} -} \ No newline at end of file diff --git a/reference_algorithms/development_algorithms/ogbg/__init__.py b/reference_algorithms/development_algorithms/ogbg/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/ogbg/ogbg_jax/__init__.py b/reference_algorithms/development_algorithms/ogbg/ogbg_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/ogbg/ogbg_jax/submission.py b/reference_algorithms/development_algorithms/ogbg/ogbg_jax/submission.py deleted file mode 100644 index 28b512589..000000000 --- a/reference_algorithms/development_algorithms/ogbg/ogbg_jax/submission.py +++ /dev/null @@ -1,122 +0,0 @@ -from typing import Dict, Iterator, List, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. 
- batch_sizes = {'ogbg': 2048} - return batch_sizes[workload_name] - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates an Adam optimizer.""" - del model_params - del model_state - del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = opt_init_fn, opt_update_fn = optax.adam( - learning_rate=hyperparameters.learning_rate) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -def train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng): - del hyperparameters - - def _loss_fn(params): - logits_batch, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - mask_batch = batch['weights'] - loss_dict = workload.loss_fn(batch['targets'], logits_batch, mask_batch) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - updates, new_optimizer_state = opt_update_fn( - grad, optimizer_state, current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_model_state, new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del global_step - - optimizer_state, opt_update_fn = optimizer_state - pmapped_train_step = jax.pmap( - train_step, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0), - static_broadcasted_argnums=(0, 1)) - dropout_rngs = jax.random.split(rng, jax.local_device_count()) - new_model_state, new_optimizer_state, new_params = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, hyperparameters, batch, dropout_rngs) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/ogbg/ogbg_pytorch/__init__.py b/reference_algorithms/development_algorithms/ogbg/ogbg_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/ogbg/ogbg_pytorch/submission.py b/reference_algorithms/development_algorithms/ogbg/ogbg_pytorch/submission.py deleted file mode 100644 index 04f4baf9a..000000000 --- a/reference_algorithms/development_algorithms/ogbg/ogbg_pytorch/submission.py +++ /dev/null @@ -1,99 +0,0 @@ -from typing import Dict, Iterator, List, Tuple - -import torch -import torch.distributed.nn as dist_nn - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - - -def get_batch_size(workload_name): - # Return the global batch size. - batch_sizes = {'ogbg': 32768} - return batch_sizes[workload_name] - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates an Adam optimizer.""" - del workload - del model_state - del rng - optimizer_state = { - 'optimizer': - torch.optim.Adam( - model_params.parameters(), lr=hyperparameters.learning_rate), - } - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del hyperparameters - del loss_type - del eval_results - del global_step - - current_model = current_param_container - current_model.train() - optimizer_state['optimizer'].zero_grad() - - logits, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn(batch['targets'], logits, batch['weights']) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. - summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - optimizer_state['optimizer'].step() - - return optimizer_state, current_param_container, new_model_state - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/ogbg/tuning_search_space.json b/reference_algorithms/development_algorithms/ogbg/tuning_search_space.json deleted file mode 100644 index d50cc00c5..000000000 --- a/reference_algorithms/development_algorithms/ogbg/tuning_search_space.json +++ /dev/null @@ -1 +0,0 @@ -{"learning_rate": {"feasible_points": [1e-3]}} diff --git a/reference_algorithms/development_algorithms/wmt/__init__.py b/reference_algorithms/development_algorithms/wmt/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/wmt/tuning_search_space.json b/reference_algorithms/development_algorithms/wmt/tuning_search_space.json deleted file mode 100644 index ba3b24f8e..000000000 --- a/reference_algorithms/development_algorithms/wmt/tuning_search_space.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "learning_rate": {"feasible_points": [0.0625]}, - "one_minus_beta_1": {"feasible_points": [0.1]}, - "dropout_rate": {"feasible_points": [0.1]}, - "aux_dropout_rate": {"feasible_points": [0.1]}, - "epsilon": {"feasible_points": [1e-9]} -} - diff --git a/reference_algorithms/development_algorithms/wmt/wmt_jax/__init__.py b/reference_algorithms/development_algorithms/wmt/wmt_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/wmt/wmt_jax/submission.py b/reference_algorithms/development_algorithms/wmt/wmt_jax/submission.py deleted file mode 100644 index 9ef1580b2..000000000 --- a/reference_algorithms/development_algorithms/wmt/wmt_jax/submission.py +++ /dev/null @@ -1,194 +0,0 @@ -"""Training algorithm track submission functions for WMT.""" - -import functools -from typing import Dict, Iterator, List, Tuple - -from flax import jax_utils -import jax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - batch_sizes = {'wmt': 128} - return batch_sizes[workload_name] - - -def create_learning_rate_scheduler( - factors='constant * linear_warmup * rsqrt_decay', - base_learning_rate=0.5, - warmup_steps=1000, - decay_factor=0.5, - steps_per_decay=20000, - steps_per_cycle=100000): - """Creates learning rate schedule. - - Interprets factors in the factors string which can consist of: - * constant: interpreted as the constant value, - * linear_warmup: interpreted as linear warmup until warmup_steps, - * rsqrt_decay: divide by square root of max(step, warmup_steps) - * rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1) - * decay_every: Every k steps decay the learning rate by decay_factor. - * cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter. - - Args: - factors: string, factors separated by "*" that defines the schedule. - base_learning_rate: float, the starting constant for the lr schedule. - warmup_steps: int, how many steps to warm up for in the warmup schedule. - decay_factor: float, the amount to decay the learning rate by. - steps_per_decay: int, how often to decay the learning rate. - steps_per_cycle: int, steps per cycle when using cosine decay. - - Returns: - a function learning_rate(step): float -> {"learning_rate": float}, the - step-dependent lr. 
- """ - factors = [n.strip() for n in factors.split('*')] - - def step_fn(step): - """Step to learning rate function.""" - ret = 1.0 - for name in factors: - if name == 'constant': - ret *= base_learning_rate - elif name == 'linear_warmup': - ret *= jnp.minimum(1.0, step / warmup_steps) - elif name == 'rsqrt_decay': - ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) - elif name == 'rsqrt_normalized_decay': - ret *= jnp.sqrt(warmup_steps) - ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) - elif name == 'decay_every': - ret *= (decay_factor**(step // steps_per_decay)) - elif name == 'cosine_decay': - progress = jnp.maximum(0.0, - (step - warmup_steps) / float(steps_per_cycle)) - ret *= jnp.maximum(0.0, - 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0)))) - else: - raise ValueError(f'Unknown factor {name}.') - return jnp.asarray(ret, dtype=jnp.float32) - - return step_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_params - del model_state - del rng - learning_rate_fn = create_learning_rate_scheduler( - base_learning_rate=hyperparameters.learning_rate, warmup_steps=1000) - opt_init_fn, opt_update_fn = optax.adam( - b1=1.0 - hyperparameters.one_minus_beta_1, - b2=0.98, - eps=hyperparameters.epsilon, - learning_rate=learning_rate_fn) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - in_axes=(None, None, 0, 0, 0, 0, None), - axis_name='batch', - static_broadcasted_argnums=(0, 1, 6)) -def pmapped_train_step(workload, - opt_update_fn, - optimizer_state, - current_param_container, - batch, - dropout_rng, - hyperparameters): - """Perform a single training step.""" - del hyperparameters - - def _loss_fn(params): - """Loss function used for training.""" - logits, _ = workload.model_fn( - params, - batch, - model_state=None, - mode=spec.ForwardPassMode.TRAIN, - rng=dropout_rng, - update_batch_norm=False) - targets = batch['targets'] - weights = batch['weights'] - loss_dict = workload.loss_fn(targets, logits, weights, label_smoothing=0.1) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, n_valid_examples - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, n_valid_examples), grad = grad_fn(current_param_container) - # Get correct global mean loss and grad. 
- (summed_loss, n_valid_examples, grad) = jax.lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - updates, new_optimizer_state = opt_update_fn( - grad, optimizer_state, current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del eval_results - del global_step - del model_state - del loss_type - - optimizer_state, opt_update_fn = optimizer_state - dropout_rngs = jax.random.split(rng, jax.local_device_count()) - new_optimizer_state, updated_params = pmapped_train_step( - workload, - opt_update_fn, - optimizer_state, - current_param_container, - batch, - dropout_rngs, - hyperparameters) - return (new_optimizer_state, opt_update_fn), updated_params, None - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/wmt/wmt_pytorch/__init__.py b/reference_algorithms/development_algorithms/wmt/wmt_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/wmt/wmt_pytorch/submission.py b/reference_algorithms/development_algorithms/wmt/wmt_pytorch/submission.py deleted file mode 100644 index 2df681273..000000000 --- a/reference_algorithms/development_algorithms/wmt/wmt_pytorch/submission.py +++ /dev/null @@ -1,167 +0,0 @@ -from typing import Dict, Iterator, List, Tuple - -import numpy as np -import torch -import torch.distributed.nn as dist_nn - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - - -def get_batch_size(workload_name): - batch_sizes = {'wmt': 128} - return batch_sizes[workload_name] - - -def create_learning_rate_scheduler( - factors='constant * linear_warmup * rsqrt_decay', - base_learning_rate=0.5, - warmup_steps=1000, - decay_factor=0.5, - steps_per_decay=20000, - steps_per_cycle=100000): - """Creates learning rate schedule. 
- Interprets factors in the factors string which can consist of: - * constant: interpreted as the constant value, - * linear_warmup: interpreted as linear warmup until warmup_steps, - * rsqrt_decay: divide by square root of max(step, warmup_steps) - * rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1) - * decay_every: Every k steps decay the learning rate by decay_factor. - * cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter. - Args: - factors: string, factors separated by "*" that defines the schedule. - base_learning_rate: float, the starting constant for the lr schedule. - warmup_steps: int, how many steps to warm up for in the warmup schedule. - decay_factor: float, the amount to decay the learning rate by. - steps_per_decay: int, how often to decay the learning rate. - steps_per_cycle: int, steps per cycle when using cosine decay. - Returns: - a function learning_rate(step): float -> {"learning_rate": float}, the - step-dependent lr. - """ - factors = [n.strip() for n in factors.split('*')] - - def step_fn(step): - """Step to learning rate function.""" - ret = 1.0 - for name in factors: - if name == 'constant': - ret *= base_learning_rate - elif name == 'linear_warmup': - ret *= np.minimum(1.0, step / warmup_steps) - elif name == 'rsqrt_decay': - ret /= np.sqrt(np.maximum(step, warmup_steps)) - elif name == 'rsqrt_normalized_decay': - ret *= np.sqrt(warmup_steps) - ret /= np.sqrt(np.maximum(step, warmup_steps)) - elif name == 'decay_every': - ret *= (decay_factor**(step // steps_per_decay)) - elif name == 'cosine_decay': - progress = np.maximum(0.0, - (step - warmup_steps) / float(steps_per_cycle)) - ret *= np.maximum(0.0, 0.5 * (1.0 + np.cos(np.pi * (progress % 1.0)))) - else: - raise ValueError(f'Unknown factor {name}.') - return ret - - return step_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del workload - del model_state - del rng - - optimizer_state = { - 'optimizer': - torch.optim.Adam( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta_1, 0.98), - eps=hyperparameters.epsilon), - } - - optimizer_state['scheduler'] = create_learning_rate_scheduler( - base_learning_rate=hyperparameters.learning_rate) - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del hyperparameters - del loss_type - del eval_results - - current_model = current_param_container - current_param_container.train() - optimizer = optimizer_state['optimizer'] - optimizer.zero_grad() - - logits, _ = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=False) - - targets = batch['targets'] - weights = batch['weights'] - loss_dict = workload.loss_fn(targets, logits, weights, label_smoothing=0.1) - summed_loss = 
loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. - summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - - lr = optimizer_state['scheduler'](global_step).item() - for g in optimizer.param_groups: - g['lr'] = lr - optimizer.step() - - return (optimizer_state, current_param_container, None) - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index ae834f1f4..099eb6765 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -9,8 +9,8 @@ Assumes that each reference submission is using the external tuning ruleset and that it is defined in: # pylint: disable=line-too-long -"reference_algorithms/development_algorithms/{workload}/{workload}_{framework}/submission.py" -"reference_algorithms/development_algorithms/{workload}/tuning_search_space.json". +"reference_algorithms/target_setting_algorithms/{workload}/{workload}_{framework}/submission.py" +"reference_algorithms/target_setting_algorithms/{workload}/tuning_search_space.json". 
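Note: the deleted ogbg and WMT PyTorch submissions above share the same DDP loss-scaling pattern: sum the masked loss and the count of valid examples on each rank, all-reduce both with the autograd-aware collective, and only then divide. Below is a minimal sketch of that idea, not code from these patches; it assumes a torch.distributed process group is already initialized and uses illustrative names.

    import torch
    import torch.distributed.nn as dist_nn

    def global_mean_loss(per_example_losses: torch.Tensor,
                         weights: torch.Tensor) -> torch.Tensor:
        # Per-rank sums of the masked loss and of the number of valid examples.
        summed_loss = (per_example_losses * weights).sum()
        n_valid_examples = weights.sum()
        # dist_nn.all_reduce (default op: SUM) is differentiable, so the
        # backward pass sees the same cross-rank scaling as the loss itself.
        summed_loss = dist_nn.all_reduce(summed_loss)
        n_valid_examples = dist_nn.all_reduce(n_valid_examples)
        # Dividing after the reduction gives the true global mean, identical
        # to what a single process would compute over the full global batch.
        return summed_loss / n_valid_examples

Summing first and dividing once avoids the bias that averaging per-rank means would introduce when ranks hold different numbers of valid (unpadded) examples.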
python3 tests/reference_algorithm_tests.py \ --workload=criteo1tb \ @@ -19,6 +19,7 @@ --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py \ --tuning_search_space=reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json """ + import copy import functools import importlib @@ -79,10 +80,7 @@ 'librispeech_conformer': ['train/wer', 'validation/wer', 'train/ctc_loss'], 'librispeech_deepspeech': ['train/wer', 'validation/wer', 'train/ctc_loss'], 'mnist': ['train/loss', 'validation/accuracy', 'test/accuracy'], - 'ogbg': [ - 'train/accuracy', 'validation/loss', 'test/mean_average_precision' - ], - 'wmt': ['train/bleu', 'validation/loss', 'validation/accuracy'], + } @@ -499,10 +497,10 @@ def _make_paths(repo_location, framework, workload_name): else: dataset_name = workload_name workload_dir = ( - f'{repo_location}/reference_algorithms/development_algorithms/' + f'{repo_location}/reference_algorithms/target_setting_algorithms/' f'{workload_name}') search_space_path = f'{workload_dir}/tuning_search_space.json' - submission_path = (f'reference_algorithms/development_algorithms/' + submission_path = (f'reference_algorithms/target_setting_algorithms/' f'{workload_name}/{dataset_name}_{framework}/' 'submission.py') full_submission_path = f'{repo_location}/{submission_path}' @@ -534,7 +532,7 @@ def test_submission(self): if FLAGS.tuning_search_space: raise ValueError('Cannot set --tuning_search_space and --all.') references_dir = ( - f'{repo_location}/reference_algorithms/development_algorithms') + f'{repo_location}/reference_algorithms/target_setting_algorithms') for workload_name in os.listdir(references_dir): for framework in ['jax', 'pytorch']: if framework == 'pytorch': diff --git a/tests/submission_runner_test.py b/tests/submission_runner_test.py index cc98e603e..75454dfc9 100644 --- a/tests/submission_runner_test.py +++ b/tests/submission_runner_test.py @@ -21,7 +21,7 @@ # (see https://github.com/google/model_search/pull/8). 
FLAGS(sys.argv) -_MNIST_DEV_ALGO_DIR = 'reference_algorithms/development_algorithms/mnist' +_MNIST_DEV_ALGO_DIR = 'reference_algorithms/target_setting_algorithms/mnist' class SubmissionRunnerTest(parameterized.TestCase): From e139cdc077f37b414e0b618b316e30b0f066e621 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 10 Oct 2023 20:34:05 -0700 Subject: [PATCH 35/48] Set criteo test targets based on external runs (#541) * set criteo test targets based on external runs * rounding * update --- algorithmic_efficiency/workloads/criteo1tb/workload.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/workload.py b/algorithmic_efficiency/workloads/criteo1tb/workload.py index ef971bb75..13bd308fb 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/workload.py @@ -35,14 +35,14 @@ def has_reached_validation_target(self, eval_result: Dict[str, @property def validation_target_value(self) -> float: - return 0.123649 + return 0.123735 def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: return eval_result['test/loss'] < self.test_target_value @property def test_target_value(self) -> float: - return 0.126060 + return 0.126041 @property def loss_type(self) -> spec.LossType: From 725199200b32583a92924500bd9f6ceb647153d0 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 11 Oct 2023 19:58:40 +0000 Subject: [PATCH 36/48] move logging in tuning loop --- submission_runner.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 47730d3fc..8605e74b8 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -560,15 +560,14 @@ def score_submission_on_workload(workload: spec.Workload, save_checkpoints=save_checkpoints,) all_timings.append(timing) all_metrics.append(metrics) - score = min(all_timings) - for ti, _ in tuning_search_space_iter: - logging.info(f'Tuning trial {ti + 1}/{num_tuning_trials}') - logging.info(f'Hyperparameters: {tuning_search_space[ti]}') - logging.info(f'Metrics: {all_metrics[ti]}') - logging.info(f'Timing: {all_timings[ti]}') - num_evals = len(all_metrics[ti]['eval_results']) + logging.info(f'Tuning trial {hi + 1}/{num_tuning_trials}') + logging.info(f'Hyperparameters: {tuning_search_space[hi]}') + logging.info(f'Metrics: {all_metrics[hi]}') + logging.info(f'Timing: {all_timings[hi]}') + num_evals = len(all_metrics[hi]['eval_results']) logging.info(f'Total number of evals: {num_evals}') logging.info('=' * 20) + score = min(all_timings) else: if tuning_search_space is not None: raise ValueError( From 09ceeecda409730079e7991da0194acea68902ee Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 11 Oct 2023 21:15:53 +0000 Subject: [PATCH 37/48] remove conformer oom fixes from this branch --- .../workloads/librispeech_conformer/workload.py | 3 +-- submission_runner.py | 8 ++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py index f15b322dd..2ad355975 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py @@ -67,8 +67,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 122136 # ~34h extended max_allowed_run_time for conformer OOM issue - + return 61_068 # ~17 hours 
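A quick sanity check on the hour comments in this hunk (illustrative arithmetic, not part of the patch):

    new_budget_sec = 61_068
    old_budget_sec = 122_136
    print(new_budget_sec / 3600)   # ~16.96 -> matches the "~17 hours" comment
    print(old_budget_sec / 3600)   # ~33.93 -> the reverted "~34h" extended budget
    assert old_budget_sec == 2 * new_budget_sec  # the revert halves the budget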
@property def eval_period_time_sec(self) -> int: diff --git a/submission_runner.py b/submission_runner.py index 545b2c016..56628d602 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -50,7 +50,6 @@ # disable only for deepspeech if it works fine for other workloads. os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' -os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' # TODO(znado): make a nicer registry of workloads that lookup in. BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR @@ -216,8 +215,10 @@ def train_once( model_params, model_state = workload.init_model_fn( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: - compile_error_workloads = ['ogbg', 'criteo1tb', 'librispeech_conformer'] - eager_backend_workloads = ['librispeech_deepspeech'] + compile_error_workloads = ['ogbg', 'criteo1tb'] + eager_backend_workloads = [ + 'librispeech_conformer', 'librispeech_deepspeech' + ] aot_eager_backend_workloads = [] if FLAGS.workload in compile_error_workloads: logging.warning( @@ -601,7 +602,6 @@ def main(_): # Prevent OOM on librispeech conformer. if FLAGS.workload == 'librispeech_conformer': os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85' - os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' # Extend path according to framework. workload_metadata['workload_path'] = os.path.join( From a0b624e689dcd1909a22bb5d99fcaaa5de34e6f2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 11 Oct 2023 21:20:22 +0000 Subject: [PATCH 38/48] lint --- algorithmic_efficiency/logger_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 5373ed927..228b4fcf3 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -273,7 +273,7 @@ def get_meta_data(workload: spec.Workload, rng_seed: int = None) -> dict: system_hardware_info = _get_system_hardware_info() meta_data.update(system_hardware_info) if rng_seed: - meta_data.update({'rng_seed': rng_seed}) + meta_data.update({'rng_seed': rng_seed}) return meta_data From 39ac8af98e02728a095d47c5f0546dd96bbe1bed Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Thu, 12 Oct 2023 09:42:02 -0400 Subject: [PATCH 39/48] Initial Commit --- algorithmic_efficiency/pytorch_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/algorithmic_efficiency/pytorch_utils.py b/algorithmic_efficiency/pytorch_utils.py index 1b8612fd1..8e079c41d 100644 --- a/algorithmic_efficiency/pytorch_utils.py +++ b/algorithmic_efficiency/pytorch_utils.py @@ -27,8 +27,6 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: # Make sure no GPU memory is preallocated to Jax. os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' # Only use CPU for Jax to avoid memory issues. - # Setting the corresponding environment variable here has no effect; it has to - # be done before jax and tensorflow (!) are imported for the first time. jax.config.update('jax_platforms', 'cpu') # From the docs: "(...) causes cuDNN to benchmark multiple convolution # algorithms and select the fastest." 
From 66af5036b0fee86105c6d3b3b8703edaa9628660 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Thu, 12 Oct 2023 09:44:19 -0400 Subject: [PATCH 40/48] Add platform --- algorithmic_efficiency/pytorch_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/pytorch_utils.py b/algorithmic_efficiency/pytorch_utils.py index 8e079c41d..4f6c254bd 100644 --- a/algorithmic_efficiency/pytorch_utils.py +++ b/algorithmic_efficiency/pytorch_utils.py @@ -28,6 +28,7 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' # Only use CPU for Jax to avoid memory issues. jax.config.update('jax_platforms', 'cpu') + jax.config.update('jax_platform_name', 'cpu') # From the docs: "(...) causes cuDNN to benchmark multiple convolution # algorithms and select the fastest." torch.backends.cudnn.benchmark = True From b6af8a03bb063f2c377eca01895722da791aa0a1 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Thu, 12 Oct 2023 09:46:38 -0400 Subject: [PATCH 41/48] minor --- submission_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/submission_runner.py b/submission_runner.py index 8605e74b8..4144e5b9b 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -28,6 +28,7 @@ from absl import flags from absl import logging import jax +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" import tensorflow as tf import torch import torch.distributed as dist From df2b1fa41118950b1950ecc92ccdf0baa6d11719 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Thu, 12 Oct 2023 09:48:25 -0400 Subject: [PATCH 42/48] minor --- submission_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 4144e5b9b..2d185a37f 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -28,7 +28,6 @@ from absl import flags from absl import logging import jax -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" import tensorflow as tf import torch import torch.distributed as dist @@ -45,6 +44,9 @@ from algorithmic_efficiency.pytorch_utils import sync_ddp_time from algorithmic_efficiency.workloads import workloads + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +import tensorflow as tf # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.set_visible_devices([], 'GPU') From bdce3a6b707a19bcae5b9cc302bd0f57a70669b0 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Thu, 12 Oct 2023 09:52:19 -0400 Subject: [PATCH 43/48] Lint fix --- submission_runner.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 2d185a37f..ee875f559 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -28,7 +28,6 @@ from absl import flags from absl import logging import jax -import tensorflow as tf import torch import torch.distributed as dist @@ -44,9 +43,9 @@ from algorithmic_efficiency.pytorch_utils import sync_ddp_time from algorithmic_efficiency.workloads import workloads - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. import tensorflow as tf + # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.set_visible_devices([], 'GPU') @@ -350,8 +349,8 @@ def train_once( train_state['is_time_remaining'] = ( train_state['accumulated_submission_time'] < max_allowed_runtime_sec) # Check if submission is eligible for an untimed eval. 
- if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) + >= workload.eval_period_time_sec or train_state['training_complete']): with profiler.profile('Evaluation'): del batch _reset_cuda_mem() From 59741e07831b09281de7d3cb3c40332978543e73 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Thu, 12 Oct 2023 09:55:22 -0400 Subject: [PATCH 44/48] Lint fix --- submission_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index ee875f559..e02eb0238 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -349,8 +349,8 @@ def train_once( train_state['is_time_remaining'] = ( train_state['accumulated_submission_time'] < max_allowed_runtime_sec) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) - >= workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) >= + workload.eval_period_time_sec or train_state['training_complete']): with profiler.profile('Evaluation'): del batch _reset_cuda_mem() From 3237d654110b854ed665e3dfad06c3ed8a126ecf Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Thu, 12 Oct 2023 09:56:47 -0400 Subject: [PATCH 45/48] Replace order --- submission_runner.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index e02eb0238..de094f984 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -31,6 +31,13 @@ import torch import torch.distributed as dist +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +import tensorflow as tf + +# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make +# it unavailable to JAX. +tf.config.set_visible_devices([], 'GPU') + from algorithmic_efficiency import checkpoint_utils from algorithmic_efficiency import halton from algorithmic_efficiency import logger_utils @@ -43,13 +50,6 @@ from algorithmic_efficiency.pytorch_utils import sync_ddp_time from algorithmic_efficiency.workloads import workloads -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. -import tensorflow as tf - -# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make -# it unavailable to JAX. -tf.config.set_visible_devices([], 'GPU') - # disable only for deepspeech if it works fine for other workloads. 
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' From 3706caf28fd788249448fd38fec6c87269e66e26 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Thu, 12 Oct 2023 10:16:30 -0400 Subject: [PATCH 46/48] Add discarded workloads --- tests/reference_algorithm_tests.py | 5 ++++- tests/submission_runner_test.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index 099eb6765..5c43b233b 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -80,7 +80,10 @@ 'librispeech_conformer': ['train/wer', 'validation/wer', 'train/ctc_loss'], 'librispeech_deepspeech': ['train/wer', 'validation/wer', 'train/ctc_loss'], 'mnist': ['train/loss', 'validation/accuracy', 'test/accuracy'], - + 'ogbg': [ + 'train/accuracy', 'validation/loss', 'test/mean_average_precision' + ], + 'wmt': ['train/bleu', 'validation/loss', 'validation/accuracy'], } diff --git a/tests/submission_runner_test.py b/tests/submission_runner_test.py index 75454dfc9..cc98e603e 100644 --- a/tests/submission_runner_test.py +++ b/tests/submission_runner_test.py @@ -21,7 +21,7 @@ # (see https://github.com/google/model_search/pull/8). FLAGS(sys.argv) -_MNIST_DEV_ALGO_DIR = 'reference_algorithms/target_setting_algorithms/mnist' +_MNIST_DEV_ALGO_DIR = 'reference_algorithms/development_algorithms/mnist' class SubmissionRunnerTest(parameterized.TestCase): From 061d5b37243a76bdc18c33a81f3a2726c196cc91 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 13 Oct 2023 20:46:57 +0000 Subject: [PATCH 47/48] pr feedback --- algorithmic_efficiency/logger_utils.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 228b4fcf3..4a3c2ec4b 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -9,7 +9,7 @@ import shutil import subprocess import sys -from typing import Any, Optional +from typing import Any, Optional, Dict from absl import flags from clu import metric_writers @@ -96,14 +96,14 @@ def write_hparams(hparams: spec.Hyperparameters, return hparams -def write_json(name: str, log_dict: dict, indent: int = 2) -> None: +def write_json(name: str, log_dict: Dict, indent: int = 2) -> None: if RANK == 0: with open(name, 'w') as f: f.write(json.dumps(log_dict, indent=indent)) def write_to_csv( - metrics: dict, + metrics: Dict, csv_path: str, ) -> None: try: @@ -120,7 +120,7 @@ def write_to_csv( return -def _get_utilization() -> dict: +def _get_utilization() -> Dict: util_data = {} # CPU @@ -180,7 +180,7 @@ def _get_utilization() -> dict: return util_data -def _get_system_hardware_info() -> dict: +def _get_system_hardware_info() -> Dict: system_hardware_info = {} try: system_hardware_info['cpu_model_name'] = _get_cpu_model_name() @@ -200,7 +200,7 @@ def _get_system_hardware_info() -> dict: return system_hardware_info -def _get_system_software_info() -> dict: +def _get_system_software_info() -> Dict: system_software_info = {} system_software_info['os_platform'] = \ @@ -243,7 +243,7 @@ def _is_primitive_type(item: Any) -> bool: return isinstance(item, primitive) -def _get_workload_properties(workload: spec.Workload) -> dict: +def _get_workload_properties(workload: spec.Workload) -> Dict: workload_properties = {} skip_list = ['param_shapes', 'model_params_types'] keys = [ @@ -262,7 +262,8 @@ def _get_workload_properties(workload: spec.Workload) -> dict: return 
workload_properties -def get_meta_data(workload: spec.Workload, rng_seed: int = None) -> dict: +def get_meta_data(workload: spec.Workload, + rng_seed: Optional[int] = None) -> Dict: meta_data = {} workload_properties = _get_workload_properties(workload) meta_data.update(workload_properties) @@ -272,7 +273,7 @@ def get_meta_data(workload: spec.Workload, rng_seed: int = None) -> dict: meta_data.update(system_software_info) system_hardware_info = _get_system_hardware_info() meta_data.update(system_hardware_info) - if rng_seed: + if rng_seed is not None: meta_data.update({'rng_seed': rng_seed}) return meta_data @@ -304,7 +305,7 @@ def __init__(self, wandb.config.update(hyperparameters._asdict()) def append_scalar_metrics(self, - metrics: dict, + metrics: Dict, global_step: int, preemption_count: Optional[int] = None, is_eval: bool = False) -> None: From a4bb0f0d7c4945fa52a4082f99b530280647eeb1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 13 Oct 2023 20:51:45 +0000 Subject: [PATCH 48/48] isort --- algorithmic_efficiency/logger_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 4a3c2ec4b..b7bde226a 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -9,7 +9,7 @@ import shutil import subprocess import sys -from typing import Any, Optional, Dict +from typing import Any, Dict, Optional from absl import flags from clu import metric_writers