diff --git a/README.md b/README.md
index 6ffbab6f7..197ba8a61 100644
--- a/README.md
+++ b/README.md
@@ -241,6 +241,11 @@ The JAX and PyTorch versions of the Criteo, FastMRI, Librispeech, OGBG, and WMT
 Since we use PyTorch's [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) implementation, there is one Python process for each device. Depending on the hardware and the settings of the cluster, running a TensorFlow input pipeline in each Python process can lead to errors, since too many threads are created in each process. See [this PR thread](https://github.com/mlcommons/algorithmic-efficiency/pull/85) for more details.
 While this issue might not affect all setups, we currently implement a different strategy: we only run the TensorFlow input pipeline in one Python process (with `rank == 0`), and [broadcast](https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast) the batches to all other devices. This introduces an additional communication overhead for each batch. See the [implementation for the WMT workload](https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py#L215-L288) as an example.
 
+## PyTorch Conformer CUDA OOM
+
+The PyTorch Conformer workload may run out of CUDA memory in its current state. If you encounter this issue, set the `submission_runner.py` flag `set_pytorch_max_split_size` to `True` as a temporary workaround. This sets `PYTORCH_CUDA_ALLOC_CONF` to `max_split_size_mb:256`. Note that this will adversely impact the performance of the submission on this workload. See the [tracking issue](https://github.com/mlcommons/algorithmic-efficiency/issues/497) for details.
+
+
 # FAQS
 
 ## Setup and Platform
diff --git a/submission_runner.py b/submission_runner.py
index 656599a42..d92732145 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -149,6 +149,9 @@
     None,
     'Value of rng seed. If None, a random seed will'
     'be generated from hardware.')
+flags.DEFINE_boolean('set_pytorch_max_split_size',
+                     False,
+                     'If true, set PyTorch max_split_size_mb to 256.')
 
 FLAGS = flags.FLAGS
 USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
@@ -602,6 +605,9 @@ def main(_):
   if FLAGS.workload == 'librispeech_conformer':
     os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85'
 
+  if FLAGS.set_pytorch_max_split_size:
+    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
+
   # Extend path according to framework.
   workload_metadata['workload_path'] = os.path.join(
       BASE_WORKLOADS_DIR,
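
As context for reviewers, here is a minimal sketch (not part of the patch) of why the flag is applied via `os.environ` early in `main()`: PyTorch's caching allocator reads `PYTORCH_CUDA_ALLOC_CONF` when it initializes, so the variable must be set before the first CUDA allocation. The snippet assumes a CUDA-capable machine and is purely illustrative.

```python
import os

# Mirror what submission_runner.py does when --set_pytorch_max_split_size is
# passed: the allocator config must be in the environment before the first
# CUDA allocation, since the caching allocator reads it at initialization.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'

import torch  # imported after setting the env var

if torch.cuda.is_available():
  # With max_split_size_mb:256 the allocator will not split cached blocks
  # larger than 256 MB, which reduces fragmentation for workloads with many
  # large, variably sized allocations (such as Conformer), at some cost in
  # throughput.
  x = torch.randn(4096, 4096, device='cuda')
  print(torch.cuda.memory_allocated() / 2**20, 'MiB allocated')
```

Passing `--set_pytorch_max_split_size` to `submission_runner.py` has the same effect, since the flag simply sets this environment variable before workload setup.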