diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 292ecad3838664..5b8ffeafc7c8ea 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -136,6 +136,7 @@ SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME, + XLA_FSDPV2_MIN_VERSION, PushInProgress, PushToHubMixin, can_return_loss, @@ -179,8 +180,14 @@ if is_torch_xla_available(): import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met - import torch_xla.distributed.spmd as xs - import torch_xla.runtime as xr + from torch_xla import __version__ as XLA_VERSION + + IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION) + if IS_XLA_FSDPV2_POST_2_2: + import torch_xla.distributed.spmd as xs + import torch_xla.runtime as xr +else: + IS_XLA_FSDPV2_POST_2_2 = False if is_sagemaker_mp_enabled(): @@ -664,6 +671,8 @@ def __init__( self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False) if self.is_fsdp_xla_v2_enabled: + if not IS_XLA_FSDPV2_POST_2_2: + raise ValueError("FSDPv2 requires `torch_xla` 2.2 or higher.") # Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper. # Tensor axis is just a placeholder where it will not be used in FSDPv2. num_devices = xr.global_runtime_device_count() diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index d33a6732458f71..121c4dc1361e4e 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -98,6 +98,7 @@ USE_JAX, USE_TF, USE_TORCH, + XLA_FSDPV2_MIN_VERSION, DummyObject, OptionalDependencyNotAvailable, _LazyModule, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 486df111856ae9..a8c45aeac33f16 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -89,6 +89,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ ACCELERATE_MIN_VERSION = "0.21.0" FSDP_MIN_VERSION = "1.12.0" +XLA_FSDPV2_MIN_VERSION = "2.2.0" _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)