From 39ac8af98e02728a095d47c5f0546dd96bbe1bed Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Thu, 12 Oct 2023 09:42:02 -0400 Subject: [PATCH 1/7] Initial Commit --- algorithmic_efficiency/pytorch_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/algorithmic_efficiency/pytorch_utils.py b/algorithmic_efficiency/pytorch_utils.py index 1b8612fd1..8e079c41d 100644 --- a/algorithmic_efficiency/pytorch_utils.py +++ b/algorithmic_efficiency/pytorch_utils.py @@ -27,8 +27,6 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: # Make sure no GPU memory is preallocated to Jax. os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' # Only use CPU for Jax to avoid memory issues. - # Setting the corresponding environment variable here has no effect; it has to - # be done before jax and tensorflow (!) are imported for the first time. jax.config.update('jax_platforms', 'cpu') # From the docs: "(...) causes cuDNN to benchmark multiple convolution # algorithms and select the fastest." From 66af5036b0fee86105c6d3b3b8703edaa9628660 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Thu, 12 Oct 2023 09:44:19 -0400 Subject: [PATCH 2/7] Add platform --- algorithmic_efficiency/pytorch_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/pytorch_utils.py b/algorithmic_efficiency/pytorch_utils.py index 8e079c41d..4f6c254bd 100644 --- a/algorithmic_efficiency/pytorch_utils.py +++ b/algorithmic_efficiency/pytorch_utils.py @@ -28,6 +28,7 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' # Only use CPU for Jax to avoid memory issues. jax.config.update('jax_platforms', 'cpu') + jax.config.update('jax_platform_name', 'cpu') # From the docs: "(...) causes cuDNN to benchmark multiple convolution # algorithms and select the fastest." 
torch.backends.cudnn.benchmark = True From b6af8a03bb063f2c377eca01895722da791aa0a1 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Thu, 12 Oct 2023 09:46:38 -0400 Subject: [PATCH 3/7] minor --- submission_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/submission_runner.py b/submission_runner.py index 8605e74b8..4144e5b9b 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -28,6 +28,7 @@ from absl import flags from absl import logging import jax +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" import tensorflow as tf import torch import torch.distributed as dist From df2b1fa41118950b1950ecc92ccdf0baa6d11719 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Thu, 12 Oct 2023 09:48:25 -0400 Subject: [PATCH 4/7] minor --- submission_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 4144e5b9b..2d185a37f 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -28,7 +28,6 @@ from absl import flags from absl import logging import jax -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" import tensorflow as tf import torch import torch.distributed as dist @@ -45,6 +44,9 @@ from algorithmic_efficiency.pytorch_utils import sync_ddp_time from algorithmic_efficiency.workloads import workloads + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +import tensorflow as tf # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. 
tf.config.set_visible_devices([], 'GPU') From bdce3a6b707a19bcae5b9cc302bd0f57a70669b0 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Thu, 12 Oct 2023 09:52:19 -0400 Subject: [PATCH 5/7] Lint fix --- submission_runner.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 2d185a37f..ee875f559 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -28,7 +28,6 @@ from absl import flags from absl import logging import jax -import tensorflow as tf import torch import torch.distributed as dist @@ -44,9 +43,9 @@ from algorithmic_efficiency.pytorch_utils import sync_ddp_time from algorithmic_efficiency.workloads import workloads - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. import tensorflow as tf + # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.set_visible_devices([], 'GPU') @@ -350,8 +349,8 @@ def train_once( train_state['is_time_remaining'] = ( train_state['accumulated_submission_time'] < max_allowed_runtime_sec) # Check if submission is eligible for an untimed eval. 
- if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) + >= workload.eval_period_time_sec or train_state['training_complete']): with profiler.profile('Evaluation'): del batch _reset_cuda_mem() From 59741e07831b09281de7d3cb3c40332978543e73 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Thu, 12 Oct 2023 09:55:22 -0400 Subject: [PATCH 6/7] Lint fix --- submission_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index ee875f559..e02eb0238 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -349,8 +349,8 @@ def train_once( train_state['is_time_remaining'] = ( train_state['accumulated_submission_time'] < max_allowed_runtime_sec) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) - >= workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) >= + workload.eval_period_time_sec or train_state['training_complete']): with profiler.profile('Evaluation'): del batch _reset_cuda_mem() From 3237d654110b854ed665e3dfad06c3ed8a126ecf Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Thu, 12 Oct 2023 09:56:47 -0400 Subject: [PATCH 7/7] Replace order --- submission_runner.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index e02eb0238..de094f984 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -31,6 +31,13 @@ import torch import torch.distributed as dist +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +import tensorflow as tf + +# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make +# it unavailable to JAX. 
+tf.config.set_visible_devices([], 'GPU') + from algorithmic_efficiency import checkpoint_utils from algorithmic_efficiency import halton from algorithmic_efficiency import logger_utils @@ -43,13 +50,6 @@ from algorithmic_efficiency.pytorch_utils import sync_ddp_time from algorithmic_efficiency.workloads import workloads -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. -import tensorflow as tf - -# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make -# it unavailable to JAX. -tf.config.set_visible_devices([], 'GPU') - # disable only for deepspeech if it works fine for other workloads. os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'