src/transformers/integrations/deepspeed.py (47 changes: 0 additions & 47 deletions)
@@ -130,58 +130,11 @@ def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):

     fill_only = partialmethod(fill_match, must_match=False)

-    def override_training_args_from_deepspeed(self, args):
-        """
-        Override TrainingArguments based on DeepSpeed config values to ensure compatibility.
-
-        This method ensures that the DeepSpeed config takes precedence over TrainingArguments
-        defaults when there are conflicts, particularly for mixed precision settings.
-
-        Args:
-            args: TrainingArguments object to potentially modify
-        """
-        # Check precision settings in DeepSpeed config and override TrainingArguments accordingly
-        # Only override defaults, not explicit user settings
-
-        # Check if user explicitly set precision options (we assume defaults are False)
-        user_set_fp16 = args.fp16 is True
-        user_set_bf16 = args.bf16 is True
-
-        if self.is_true("fp16.enabled"):
-            # DeepSpeed config explicitly enables fp16
-            if not user_set_fp16 and not user_set_bf16:
-                # User didn't explicitly set either, so apply DeepSpeed config
-                args.fp16 = True
-                args.bf16 = False
-            elif user_set_bf16 and not user_set_fp16:
-                # User explicitly chose bf16, but DeepSpeed config wants fp16
-                # This is a potential conflict - let user choice win but log a warning
-                pass  # Keep user's bf16=True, fp16=False
-        elif self.is_true("bf16.enabled"):
-            # DeepSpeed config explicitly enables bf16
-            if not user_set_fp16 and not user_set_bf16:
-                # User didn't explicitly set either, so apply DeepSpeed config
-                args.bf16 = True
-                args.fp16 = False
-            elif user_set_fp16 and not user_set_bf16:
-                # User explicitly chose fp16, but DeepSpeed config wants bf16
-                # This is a potential conflict - let user choice win but log a warning
-                pass  # Keep user's fp16=True, bf16=False
-        elif self.is_false("fp16.enabled") and self.is_false("bf16.enabled"):
-            # Both are explicitly disabled in DeepSpeed config
-            if not user_set_fp16 and not user_set_bf16:
-                # User didn't explicitly set either, so apply DeepSpeed config (fp32)
-                args.fp16 = False
-                args.bf16 = False
-
     def trainer_config_process(self, args, auto_find_batch_size=False):
         """
         Adjust the config with `TrainingArguments` values. This stage is run during `TrainingArguments` object
         creation.
         """
-        # First, override TrainingArguments based on DeepSpeed config to ensure compatibility
-        self.override_training_args_from_deepspeed(args)
-
         # DeepSpeed does:
         # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
         train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
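The comment at the end of the hunk above states the batch-size relation DeepSpeed enforces. As a quick worked example in Python (the concrete numbers below are made up for illustration and do not come from this PR):

# train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
world_size = 8                     # number of training processes (e.g. GPUs); illustrative value
per_device_train_batch_size = 4    # becomes DeepSpeed's train_micro_batch_size_per_gpu; illustrative value
gradient_accumulation_steps = 2    # illustrative value

train_batch_size = world_size * per_device_train_batch_size * gradient_accumulation_steps
assert train_batch_size == 64      # 8 * 4 * 2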
src/transformers/training_args.py (18 changes: 7 additions & 11 deletions)
@@ -1689,8 +1689,13 @@ def __post_init__(self):
                     torch.backends.cudnn.allow_tf32 = False
                 # no need to assert on else

-        # NOTE: Mixed precision environment variable setting moved to after DeepSpeed processing
-        # to ensure DeepSpeed config can override TrainingArguments defaults
+        # if training args is specified, it will override the one specified in the accelerate config
+        mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
+        if self.fp16:
+            mixed_precision_dtype = "fp16"
+        elif self.bf16:
+            mixed_precision_dtype = "bf16"
+        os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype

         if self.report_to is None:
             logger.info(
@@ -1877,15 +1882,6 @@ def __post_init__(self):
             self.deepspeed_plugin.set_mixed_precision(mixed_precision)
             self.deepspeed_plugin.set_deepspeed_weakref()

-        # Set mixed precision environment variable after DeepSpeed processing
-        # This ensures DeepSpeed config overrides have been applied to fp16/bf16 settings
-        mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
-        if self.fp16:
-            mixed_precision_dtype = "fp16"
-        elif self.bf16:
-            mixed_precision_dtype = "bf16"
-        os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
-
         if self.use_cpu:
             self.dataloader_pin_memory = False

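The block restored in the first hunk derives Accelerate's mixed-precision setting from the training arguments again, before the DeepSpeed plugin path applies it via set_mixed_precision later in __post_init__ (visible as context in the second hunk). A minimal standalone sketch of that precedence rule, using a hypothetical helper name resolve_mixed_precision (this is not the actual TrainingArguments code):

import os


def resolve_mixed_precision(fp16: bool, bf16: bool) -> str:
    # Start from whatever Accelerate already configured (default "no")...
    mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
    # ...but explicit fp16/bf16 training arguments take precedence over it.
    if fp16:
        mixed_precision_dtype = "fp16"
    elif bf16:
        mixed_precision_dtype = "bf16"
    return mixed_precision_dtype


# An explicit flag always wins over the pre-existing environment value.
assert resolve_mixed_precision(fp16=True, bf16=False) == "fp16"
assert resolve_mixed_precision(fp16=False, bf16=True) == "bf16"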
tests/deepspeed/test_deepspeed.py (47 changes: 0 additions & 47 deletions)
@@ -1431,50 +1431,3 @@ def test_clm_from_config_zero3_fp16(self):
         with CaptureStderr() as cs:
             execute_subprocess_async(cmd, env=self.get_env())
         self.assertIn("Detected DeepSpeed ZeRO-3", cs.err)
-
-
-@require_deepspeed
-class TestDeepSpeedMixedPrecisionPrecedence(TestCasePlus):
-    """Test DeepSpeed mixed precision precedence over Accelerate defaults."""
-
-    def setUp(self):
-        super().setUp()
-        unset_hf_deepspeed_config()
-
-    def tearDown(self):
-        super().tearDown()
-        unset_hf_deepspeed_config()
-
-    def test_deepspeed_fp16_overrides_defaults(self):
-        """Test that DeepSpeed fp16 config overrides TrainingArguments defaults"""
-        from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
-
-        args = TrainingArguments(output_dir="./test_output", fp16=False, bf16=False)
-        ds_config = {"fp16": {"enabled": True}, "bf16": {"enabled": False}, "zero_optimization": {"stage": 2}}
-        hf_ds_config = HfTrainerDeepSpeedConfig(ds_config)
-        hf_ds_config.trainer_config_process(args)
-        self.assertTrue(args.fp16)
-        self.assertFalse(args.bf16)
-
-    def test_deepspeed_bf16_overrides_defaults(self):
-        """Test that DeepSpeed bf16 config overrides TrainingArguments defaults"""
-        from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
-
-        args = TrainingArguments(output_dir="./test_output", fp16=False, bf16=False)
-        ds_config = {"fp16": {"enabled": False}, "bf16": {"enabled": True}, "zero_optimization": {"stage": 2}}
-        hf_ds_config = HfTrainerDeepSpeedConfig(ds_config)
-        hf_ds_config.trainer_config_process(args)
-        self.assertTrue(args.bf16)
-        self.assertFalse(args.fp16)
-
-    def test_user_explicit_settings_preserved(self):
-        """Test that explicit user settings are preserved over DeepSpeed config"""
-        from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
-
-        args = TrainingArguments(output_dir="./test_output", fp16=True, bf16=False)  # User explicit
-        ds_config = {"fp16": {"enabled": False}, "bf16": {"enabled": True}, "zero_optimization": {"stage": 2}}
-        hf_ds_config = HfTrainerDeepSpeedConfig(ds_config)
-        hf_ds_config.trainer_config_process(args)
-        # User's explicit choice should be preserved
-        self.assertTrue(args.fp16)
-        self.assertFalse(args.bf16)
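The removed tests exercised the override path deleted from HfTrainerDeepSpeedConfig above. Without it, a common way to keep a DeepSpeed config and the TrainingArguments precision flags from drifting apart is to let the config defer to the arguments. A hedged sketch, assuming the integration's "auto" placeholder convention (not shown in this diff), under which trainer_config_process fills "auto" values from the corresponding training arguments:

# Hedged illustration (not part of this PR): let the DeepSpeed config defer to the
# TrainingArguments instead of relying on the removed override logic.
ds_config = {
    "fp16": {"enabled": "auto"},               # assumed to follow args.fp16
    "bf16": {"enabled": "auto"},               # assumed to follow args.bf16
    "train_micro_batch_size_per_gpu": "auto",  # assumed to follow args.per_device_train_batch_size
    "gradient_accumulation_steps": "auto",     # assumed to follow args.gradient_accumulation_steps
    "zero_optimization": {"stage": 2},
}
# Typically passed as TrainingArguments(..., bf16=True, deepspeed=ds_config),
# so the precision choice lives in one place: the training arguments.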