Skip to content

Commit

Permalink
Guard XLA version imports (#30167)
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr authored and Ita Zaporozhets committed May 14, 2024
1 parent 8c3019d commit 38615d8
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 2 deletions.
13 changes: 11 additions & 2 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
XLA_FSDPV2_MIN_VERSION,
PushInProgress,
PushToHubMixin,
can_return_loss,
Expand Down Expand Up @@ -179,8 +180,14 @@
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla import __version__ as XLA_VERSION

IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
if IS_XLA_FSDPV2_POST_2_2:
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
else:
IS_XLA_FSDPV2_POST_2_2 = False


if is_sagemaker_mp_enabled():
Expand Down Expand Up @@ -664,6 +671,8 @@ def __init__(

self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False)
if self.is_fsdp_xla_v2_enabled:
if not IS_XLA_FSDPV2_POST_2_2:
raise ValueError("FSDPv2 requires `torch_xla` 2.2 or higher.")
# Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
# Tensor axis is just a placeholder where it will not be used in FSDPv2.
num_devices = xr.global_runtime_device_count()
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
USE_JAX,
USE_TF,
USE_TORCH,
XLA_FSDPV2_MIN_VERSION,
DummyObject,
OptionalDependencyNotAvailable,
_LazyModule,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[

ACCELERATE_MIN_VERSION = "0.21.0"
FSDP_MIN_VERSION = "1.12.0"
XLA_FSDPV2_MIN_VERSION = "2.2.0"


_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
Expand Down

0 comments on commit 38615d8

Please sign in to comment.