Fix training args transformers compatibility (#1770)
* Fix training args compatibility with transformers v4.38.0 or lower

* format
echarlaix committed Mar 25, 2024
1 parent 66e30ad commit 8a95414
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions optimum/onnxruntime/training_args.py
@@ -44,11 +44,13 @@
 )
 from transformers.utils.generic import strtobool
 
+from ..utils.import_utils import check_if_transformers_greater
+
 
 if is_torch_available():
     import torch
 
-if is_accelerate_available():
+if is_accelerate_available() and check_if_transformers_greater("4.38.0"):
     from transformers.trainer_pt_utils import AcceleratorConfig
 

@@ -449,7 +451,7 @@ def __post_init__(self):
             os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true")
             os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "false")
 
-        if is_accelerate_available():
+        if is_accelerate_available() and check_if_transformers_greater("4.38.0"):
             if not isinstance(self.accelerator_config, (AcceleratorConfig)):
                 if self.accelerator_config is None:
                     self.accelerator_config = AcceleratorConfig()
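
For context: the fix gates everything that only exists in newer transformers releases behind a runtime version check. AcceleratorConfig was introduced in transformers v4.38.0, so importing or instantiating it unconditionally fails on older installations. Below is a minimal sketch of how such a version-gate helper can be written with importlib.metadata and packaging; the helper name transformers_at_least is illustrative and is an assumption, not optimum's actual check_if_transformers_greater implementation.

# Minimal sketch (assumption, not optimum's actual implementation) of a
# version-gate helper in the spirit of check_if_transformers_greater:
# compare the installed transformers version against a target release so
# that symbols introduced in newer versions are only imported when present.
import importlib.metadata

from packaging import version


def transformers_at_least(target: str) -> bool:
    """Return True if the installed transformers version is >= target."""
    try:
        installed = importlib.metadata.version("transformers")
    except importlib.metadata.PackageNotFoundError:
        return False  # transformers is not installed at all
    return version.parse(installed) >= version.parse(target)


# Usage mirroring the gated import in the diff above: AcceleratorConfig only
# exists in transformers >= 4.38.0, so older releases skip the import entirely.
if transformers_at_least("4.38.0"):
    from transformers.trainer_pt_utils import AcceleratorConfig  # noqa: F401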
