diff --git a/algorithmic_efficiency/pytorch_utils.py b/algorithmic_efficiency/pytorch_utils.py index 1b8612fd1..4f6c254bd 100644 --- a/algorithmic_efficiency/pytorch_utils.py +++ b/algorithmic_efficiency/pytorch_utils.py @@ -27,9 +27,8 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: # Make sure no GPU memory is preallocated to Jax. os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' # Only use CPU for Jax to avoid memory issues. - # Setting the corresponding environment variable here has no effect; it has to - # be done before jax and tensorflow (!) are imported for the first time. jax.config.update('jax_platforms', 'cpu') + jax.config.update('jax_platform_name', 'cpu') # From the docs: "(...) causes cuDNN to benchmark multiple convolution # algorithms and select the fastest." torch.backends.cudnn.benchmark = True diff --git a/submission_runner.py b/submission_runner.py index 8605e74b8..de094f984 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -28,10 +28,16 @@ from absl import flags from absl import logging import jax -import tensorflow as tf import torch import torch.distributed as dist +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +import tensorflow as tf + +# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make +# it unavailable to JAX. +tf.config.set_visible_devices([], 'GPU') + from algorithmic_efficiency import checkpoint_utils from algorithmic_efficiency import halton from algorithmic_efficiency import logger_utils @@ -44,10 +50,6 @@ from algorithmic_efficiency.pytorch_utils import sync_ddp_time from algorithmic_efficiency.workloads import workloads -# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make -# it unavailable to JAX. -tf.config.set_visible_devices([], 'GPU') - # disable only for deepspeech if it works fine for other workloads. os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'