Revert DeepSpeed stuff #22899

Merged (17 commits) on Apr 20, 2023

src/transformers/training_args.py: 109 changes (62 additions & 47 deletions)
@@ -1544,69 +1544,84 @@ def _setup_devices(self) -> "torch.device":
             self._n_gpu = 1
             torch.cuda.set_device(device)
         elif self.deepspeed:
-            self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
+            # deepspeed inits torch.distributed internally
+            from .deepspeed import is_deepspeed_available
+
+            if not is_deepspeed_available():
+                raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.")
+            import deepspeed
+
+            deepspeed.init_distributed(timeout=timedelta(seconds=self.ddp_timeout))
+
+            # workaround for setups like notebooks where the launcher can't be used,
+            # but deepspeed requires a dist env.
+            # env LOCAL_RANK could be set manually by the user, or via init_distributed if mpi4py is installed
+            self.local_rank = int(os.environ.get("LOCAL_RANK", "-1"))
+
+            device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
         else:
             self.distributed_state = PartialState(backend=self.xpu_backend)
             self._n_gpu = 1
-        if not is_sagemaker_mp_enabled():
+        if not is_sagemaker_mp_enabled() and not self.deepspeed:
             device = self.distributed_state.device
             self.local_rank = self.distributed_state.local_process_index
         if (
             torch.distributed.is_available()
             and torch.distributed.is_initialized()
+            and hasattr(self, "distributed_state")
             and self.distributed_state.distributed_type == DistributedType.NO
         ):
             logger.warning(
                 "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. "
                 "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
             )

-        if is_torch_tpu_available():
-            device = self.distributed_state.device
-            self._n_gpu = 0
-        elif is_sagemaker_dp_enabled():
-            self._n_gpu = 1
-        elif self.distributed_state.distributed_type == DistributedType.NO:
-            if self.use_mps_device:
-                if not torch.backends.mps.is_available():
-                    if not torch.backends.mps.is_built():
-                        raise AssertionError(
-                            "MPS not available because the current PyTorch install was not "
-                            "built with MPS enabled. Please install torch version >=1.12.0 on "
-                            "your Apple silicon Mac running macOS 12.3 or later with a native "
-                            "version (arm64) of Python"
-                        )
-                    else:
-                        raise AssertionError(
-                            "MPS not available because the current MacOS version is not 12.3+ "
-                            "and/or you do not have an MPS-enabled device on this machine."
-                        )
-                else:
-                    if not version.parse(version.parse(torch.__version__).base_version) > version.parse("1.12.0"):
-                        warnings.warn(
-                            "We strongly recommend to install PyTorch >= 1.13 (nightly version at the time of writing)"
-                            " on your MacOS machine. It has major fixes related to model correctness and performance"
-                            " improvements for transformer based models. Please refer to"
-                            " https://github.com/pytorch/pytorch/issues/82707 for more details."
-                        )
-                    device = torch.device("mps")
-                    self._n_gpu = 1
+        if not self.deepspeed:
+            if is_torch_tpu_available():
+                device = self.distributed_state.device
+                self._n_gpu = 0
+            elif is_sagemaker_dp_enabled():
+                self._n_gpu = 1
+            elif self.distributed_state.distributed_type == DistributedType.NO:
+                if self.use_mps_device:
+                    if not torch.backends.mps.is_available():
+                        if not torch.backends.mps.is_built():
+                            raise AssertionError(
+                                "MPS not available because the current PyTorch install was not "
+                                "built with MPS enabled. Please install torch version >=1.12.0 on "
+                                "your Apple silicon Mac running macOS 12.3 or later with a native "
+                                "version (arm64) of Python"
+                            )
+                        else:
+                            raise AssertionError(
+                                "MPS not available because the current MacOS version is not 12.3+ "
+                                "and/or you do not have an MPS-enabled device on this machine."
+                            )
+                    if not version.parse(version.parse(torch.__version__).base_version) > version.parse("1.12.0"):
+                        warnings.warn(
+                            "We strongly recommend to install PyTorch >= 1.13 (nightly version at the time of writing)"
+                            " on your MacOS machine. It has major fixes related to model correctness and performance"
+                            " improvements for transformer based models. Please refer to"
+                            " https://github.com/pytorch/pytorch/issues/82707 for more details."
+                        )
+                    device = torch.device("mps")
+                    self._n_gpu = 1

-            else:
-                # if n_gpu is > 1 we'll use nn.DataParallel.
-                # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
-                # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
-                # trigger an error that a device index is missing. Index 0 takes into account the
-                # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
-                # will use the first GPU in that env, i.e. GPU#1
-                # device = self.distributed_state.device
-                device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-                # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
-                # the default value.
-                self._n_gpu = torch.cuda.device_count()
-                if device.type == "cuda":
-                    torch.cuda.set_device(device)
+                else:
+                    # if n_gpu is > 1 we'll use nn.DataParallel.
+                    # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
+                    # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
+                    # trigger an error that a device index is missing. Index 0 takes into account the
+                    # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
+                    # will use the first GPU in that env, i.e. GPU#1
+                    # device = self.distributed_state.device
+                    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+                    # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
+                    # the default value.
+                    self._n_gpu = torch.cuda.device_count()
+                    if device.type == "cuda":
+                        torch.cuda.set_device(device)
         return device

    @property
@@ -1649,7 +1664,7 @@ def parallel_mode(self):
             return ParallelMode.SAGEMAKER_MODEL_PARALLEL
         elif is_sagemaker_dp_enabled():
             return ParallelMode.SAGEMAKER_DATA_PARALLEL
-        elif hasattr(self, "distributed_state") and (self.distributed_state.distributed_type != DistributedType.NO):
+        elif self.deepspeed or self.distributed_state.distributed_type != DistributedType.NO:
             return ParallelMode.DISTRIBUTED
         elif self.n_gpu > 1:
             return ParallelMode.NOT_DISTRIBUTED
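For readers skimming the diff, the snippet below is an illustrative sketch, not part of the PR, of how the restored path is exercised: passing a DeepSpeed config to `TrainingArguments` routes device setup through `deepspeed.init_distributed` rather than accelerate's `PartialState`, and `parallel_mode` reports `DISTRIBUTED`. The `output_dir` and config values are placeholders, DeepSpeed is assumed to be installed, and the script is assumed to run under the `deepspeed` launcher so the distributed environment variables exist.

```python
# Illustrative only; launch with e.g. `deepspeed --num_gpus=2 this_script.py`.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",  # placeholder path
    # A dict (or a path to a JSON file) enables the `self.deepspeed` branch above.
    deepspeed={"zero_optimization": {"stage": 2}},
)

# Resolving the device runs the restored `elif self.deepspeed:` block:
# deepspeed.init_distributed() is called and the device is picked from LOCAL_RANK.
print(args.device)         # e.g. cuda:0 on local rank 0
print(args.parallel_mode)  # ParallelMode.DISTRIBUTED, via the updated parallel_mode check
```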