From 10fab90fe22b21619ed1136792102a6a8142c4ed Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Wed, 12 Apr 2023 17:42:50 +0200
Subject: [PATCH] `torch.distributed` group initialization for `torch_neuron`
 disabled when `optimum-neuron` is installed (#22728)

* Make the process group initialization not happen if optimum_neuron is installed

* Add warning

* Remove list and added warning
---
 src/transformers/training_args.py      | 29 ++++++++++++++++++--------
 src/transformers/utils/import_utils.py |  4 ++++
 2 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 28387885de166..088eb06b790f6 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -54,8 +54,13 @@
     logging,
     requires_backends,
 )
+from .utils.import_utils import is_optimum_neuron_available
 
 
+logger = logging.get_logger(__name__)
+log_levels = logging.get_log_levels_dict().copy()
+trainer_log_levels = dict(**log_levels, passive=-1)
+
 if is_torch_available():
     import torch
     import torch.distributed as dist
@@ -67,12 +72,23 @@
     # torchrun support
     # https://github.com/pytorch/xla/pull/3609
     if os.environ.get("TORCHELASTIC_RUN_ID"):
-        import torch_xla.distributed.xla_backend as xbn
+        if is_optimum_neuron_available():
+            logger.info(
+                "Make sure that you are performing the training with the TrainiumTrainer from optimum[neuron], this "
+                "will fail otherwise."
+            )
+        else:
+            logger.warning(
+                "Please use the TrainiumTrainer from optimum[neuron] instead of the Transformers library to perform "
+                "training on AWS Trainium instances. More information here: "
+                "https://github.com/huggingface/optimum-neuron"
+            )
+            import torch_xla.distributed.xla_backend as xbn
 
-        if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
-            torch.distributed.init_process_group(backend="xla")
             if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
-                raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.")
+                torch.distributed.init_process_group(backend="xla")
+                if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
+                    raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.")
 
 
 if is_sagemaker_mp_enabled():
@@ -81,11 +97,6 @@
     smp.init()
 
 
-logger = logging.get_logger(__name__)
-log_levels = logging.get_log_levels_dict().copy()
-trainer_log_levels = dict(**log_levels, passive=-1)
-
-
 def default_logdir() -> str:
     """
     Same default as PyTorch
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 10ca9dba00856..f250a670a9f4b 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -583,6 +583,10 @@ def is_optimum_available():
     return importlib.util.find_spec("optimum") is not None
 
 
+def is_optimum_neuron_available():
+    return importlib.util.find_spec("optimum.neuron") is not None
+
+
 def is_safetensors_available():
     if is_torch_available():
         if version.parse(_torch_version) >= version.parse("1.10"):