algorithmic_efficiency/pytorch_utils.py (3 changes: 1 addition & 2 deletions)
@@ -27,9 +27,8 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None:
   # Make sure no GPU memory is preallocated to Jax.
   os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
   # Only use CPU for Jax to avoid memory issues.
-  # Setting the corresponding environment variable here has no effect; it has to
-  # be done before jax and tensorflow (!) are imported for the first time.
+  jax.config.update('jax_platforms', 'cpu')
   jax.config.update('jax_platform_name', 'cpu')
   # From the docs: "(...) causes cuDNN to benchmark multiple convolution
   # algorithms and select the fastest."
   torch.backends.cudnn.benchmark = True
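The distinction behind this change is easy to trip over, so here is a minimal sketch (a hypothetical standalone script, not part of the repo) of the two ways to pin JAX to CPU. JAX_PLATFORMS and jax.devices() are standard JAX; everything else is illustrative:

# Minimal sketch, not part of the repo: two ways to pin JAX to CPU so it
# cannot grab GPU memory that PyTorch needs.
import os

# Route 1: the environment variable. This only takes effect if it is set
# before jax (and tensorflow!) are imported anywhere in the process.
os.environ['JAX_PLATFORMS'] = 'cpu'

import jax

# Route 2: the config flag. This still works after the import, as long as
# no JAX backend has been initialized yet.
jax.config.update('jax_platforms', 'cpu')

print(jax.devices())  # Expect only CPU devices, e.g. [CpuDevice(id=0)]

The config route is the only one available inside pytorch_init(), since jax has already been imported by the time that function runs.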
submission_runner.py (12 changes: 7 additions & 5 deletions)
@@ -28,10 +28,16 @@
 from absl import flags
 from absl import logging
 import jax
-import tensorflow as tf
 import torch
 import torch.distributed as dist
 
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Disables TensorRT and CUDA warnings.
+import tensorflow as tf
+
+# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
+# it unavailable to JAX.
+tf.config.set_visible_devices([], 'GPU')
+
 from algorithmic_efficiency import checkpoint_utils
 from algorithmic_efficiency import halton
 from algorithmic_efficiency import logger_utils
@@ -44,10 +50,6 @@
 from algorithmic_efficiency.pytorch_utils import sync_ddp_time
 from algorithmic_efficiency.workloads import workloads
 
-# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
-# it unavailable to JAX.
-tf.config.set_visible_devices([], 'GPU')
-
 # disable only for deepspeech if it works fine for other workloads.
 os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'

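The reordering above encodes an import-order contract: TF_CPP_MIN_LOG_LEVEL is read by TensorFlow's native libraries at first import, and set_visible_devices has to run before any TensorFlow op initializes the GPUs. A minimal sketch of the resulting pattern (the final assert is illustrative, not from the repo):

# Minimal sketch, not part of the repo: the import-order pattern the diff
# establishes in submission_runner.py.
import os

# Must be set before TensorFlow is imported; it is read at import time.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Silence TensorRT/CUDA warnings.

import tensorflow as tf

# Must run right after the import, before any TF op can initialize the GPUs
# and reserve memory that JAX needs.
tf.config.set_visible_devices([], 'GPU')

# TensorFlow now reports no visible GPUs, while JAX can still use them.
assert tf.config.get_visible_devices('GPU') == []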