diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index e30274af0424..daa7e885d272 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -130,58 +130,11 @@ def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
 
     fill_only = partialmethod(fill_match, must_match=False)
 
-    def override_training_args_from_deepspeed(self, args):
-        """
-        Override TrainingArguments based on DeepSpeed config values to ensure compatibility.
-
-        This method ensures that the DeepSpeed config takes precedence over TrainingArguments
-        defaults when there are conflicts, particularly for mixed precision settings.
-
-        Args:
-            args: TrainingArguments object to potentially modify
-        """
-        # Check precision settings in DeepSpeed config and override TrainingArguments accordingly
-        # Only override defaults, not explicit user settings
-
-        # Check if user explicitly set precision options (we assume defaults are False)
-        user_set_fp16 = args.fp16 is True
-        user_set_bf16 = args.bf16 is True
-
-        if self.is_true("fp16.enabled"):
-            # DeepSpeed config explicitly enables fp16
-            if not user_set_fp16 and not user_set_bf16:
-                # User didn't explicitly set either, so apply DeepSpeed config
-                args.fp16 = True
-                args.bf16 = False
-            elif user_set_bf16 and not user_set_fp16:
-                # User explicitly chose bf16, but DeepSpeed config wants fp16
-                # This is a potential conflict - let user choice win but log a warning
-                pass  # Keep user's bf16=True, fp16=False
-        elif self.is_true("bf16.enabled"):
-            # DeepSpeed config explicitly enables bf16
-            if not user_set_fp16 and not user_set_bf16:
-                # User didn't explicitly set either, so apply DeepSpeed config
-                args.bf16 = True
-                args.fp16 = False
-            elif user_set_fp16 and not user_set_bf16:
-                # User explicitly chose fp16, but DeepSpeed config wants bf16
-                # This is a potential conflict - let user choice win but log a warning
-                pass  # Keep user's fp16=True, bf16=False
-        elif self.is_false("fp16.enabled") and self.is_false("bf16.enabled"):
-            # Both are explicitly disabled in DeepSpeed config
-            if not user_set_fp16 and not user_set_bf16:
-                # User didn't explicitly set either, so apply DeepSpeed config (fp32)
-                args.fp16 = False
-                args.bf16 = False
-
     def trainer_config_process(self, args, auto_find_batch_size=False):
         """
         Adjust the config with `TrainingArguments` values. This stage is run during `TrainingArguments` object
         creation.
         """
-        # First, override TrainingArguments based on DeepSpeed config to ensure compatibility
-        self.override_training_args_from_deepspeed(args)
-
         # DeepSpeed does:
         # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
         train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 3520ea5702aa..1d989608a6c3 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1689,8 +1689,13 @@ def __post_init__(self):
                     torch.backends.cudnn.allow_tf32 = False
                # no need to assert on else
 
-        # NOTE: Mixed precision environment variable setting moved to after DeepSpeed processing
-        # to ensure DeepSpeed config can override TrainingArguments defaults
+        # if training args is specified, it will override the one specified in the accelerate config
+        mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
+        if self.fp16:
+            mixed_precision_dtype = "fp16"
+        elif self.bf16:
+            mixed_precision_dtype = "bf16"
+        os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
 
         if self.report_to is None:
             logger.info(
@@ -1877,15 +1882,6 @@ def __post_init__(self):
             self.deepspeed_plugin.set_mixed_precision(mixed_precision)
             self.deepspeed_plugin.set_deepspeed_weakref()
 
-        # Set mixed precision environment variable after DeepSpeed processing
-        # This ensures DeepSpeed config overrides have been applied to fp16/bf16 settings
-        mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
-        if self.fp16:
-            mixed_precision_dtype = "fp16"
-        elif self.bf16:
-            mixed_precision_dtype = "bf16"
-        os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
-
         if self.use_cpu:
             self.dataloader_pin_memory = False
 
diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py
index e3dc9fc08c99..99b1450a0d59 100644
--- a/tests/deepspeed/test_deepspeed.py
+++ b/tests/deepspeed/test_deepspeed.py
@@ -1431,50 +1431,3 @@ def test_clm_from_config_zero3_fp16(self):
         with CaptureStderr() as cs:
             execute_subprocess_async(cmd, env=self.get_env())
         self.assertIn("Detected DeepSpeed ZeRO-3", cs.err)
-
-
-@require_deepspeed
-class TestDeepSpeedMixedPrecisionPrecedence(TestCasePlus):
-    """Test DeepSpeed mixed precision precedence over Accelerate defaults."""
-
-    def setUp(self):
-        super().setUp()
-        unset_hf_deepspeed_config()
-
-    def tearDown(self):
-        super().tearDown()
-        unset_hf_deepspeed_config()
-
-    def test_deepspeed_fp16_overrides_defaults(self):
-        """Test that DeepSpeed fp16 config overrides TrainingArguments defaults"""
-        from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
-
-        args = TrainingArguments(output_dir="./test_output", fp16=False, bf16=False)
-        ds_config = {"fp16": {"enabled": True}, "bf16": {"enabled": False}, "zero_optimization": {"stage": 2}}
-        hf_ds_config = HfTrainerDeepSpeedConfig(ds_config)
-        hf_ds_config.trainer_config_process(args)
-        self.assertTrue(args.fp16)
-        self.assertFalse(args.bf16)
-
-    def test_deepspeed_bf16_overrides_defaults(self):
-        """Test that DeepSpeed bf16 config overrides TrainingArguments defaults"""
-        from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
-
-        args = TrainingArguments(output_dir="./test_output", fp16=False, bf16=False)
-        ds_config = {"fp16": {"enabled": False}, "bf16": {"enabled": True}, "zero_optimization": {"stage": 2}}
-        hf_ds_config = HfTrainerDeepSpeedConfig(ds_config)
-        hf_ds_config.trainer_config_process(args)
-        self.assertTrue(args.bf16)
-        self.assertFalse(args.fp16)
-
-    def test_user_explicit_settings_preserved(self):
-        """Test that explicit user settings are preserved over DeepSpeed config"""
-        from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
-
-        args = TrainingArguments(output_dir="./test_output", fp16=True, bf16=False)  # User explicit
-        ds_config = {"fp16": {"enabled": False}, "bf16": {"enabled": True}, "zero_optimization": {"stage": 2}}
-        hf_ds_config = HfTrainerDeepSpeedConfig(ds_config)
-        hf_ds_config.trainer_config_process(args)
-        # User's explicit choice should be preserved
-        self.assertTrue(args.fp16)
-        self.assertFalse(args.bf16)
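For reference, a minimal standalone sketch of the precedence the restored `training_args.py` hunk implements: the `TrainingArguments` `fp16`/`bf16` flags decide `ACCELERATE_MIXED_PRECISION`, falling back to whatever Accelerate already configured. This is not the actual Trainer code; the function name below is hypothetical.

```python
import os


def resolve_mixed_precision(fp16: bool, bf16: bool) -> str:
    # Start from the value Accelerate may already have set ("no" if unset).
    mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
    # Explicit TrainingArguments flags take precedence over the accelerate config.
    if fp16:
        mixed_precision_dtype = "fp16"
    elif bf16:
        mixed_precision_dtype = "bf16"
    os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
    return mixed_precision_dtype


if __name__ == "__main__":
    print(resolve_mixed_precision(fp16=True, bf16=False))  # -> "fp16"
    print(resolve_mixed_precision(fp16=False, bf16=False))  # -> unchanged env value, "fp16" here
```

With neither flag set, the environment variable (i.e. the accelerate config) is left as-is, which is why this revert restores the pre-existing behavior rather than letting the DeepSpeed config silently flip `fp16`/`bf16`.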