diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 2bb5220ab85fd..43bc33e7b034d 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2030,10 +2030,15 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
         weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
         safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
         safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
-        is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and any(
-            FSDP_MODEL_NAME in folder_name
-            for folder_name in os.listdir(resume_from_checkpoint)
-            if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
+        is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and (
+            # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
+            any(
+                FSDP_MODEL_NAME in folder_name
+                for folder_name in os.listdir(resume_from_checkpoint)
+                if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
+            )
+            # this checks the FSDP state dict when `FULL_STATE_DICT` is used
+            or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
         )

         if is_fsdp_ckpt and not self.is_fsdp_enabled:
diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py
index 2a9473c862ffa..d883f29ed3698 100644
--- a/tests/fsdp/test_fsdp.py
+++ b/tests/fsdp/test_fsdp.py
@@ -41,6 +41,7 @@

 if is_torch_available():
     from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_1
+    from transformers.trainer import FSDP_MODEL_NAME
 else:
     is_torch_greater_or_equal_than_2_1 = False

@@ -211,6 +212,19 @@ def test_training_and_can_resume_normally(self, state_dict_type):
         # resume from ckpt
         checkpoint = os.path.join(output_dir, "checkpoint-115")
         resume_args = args + f"--resume_from_checkpoint {checkpoint}".split()
+
+        is_fsdp_ckpt = os.path.isdir(checkpoint) and (
+            # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
+            any(
+                FSDP_MODEL_NAME in folder_name
+                for folder_name in os.listdir(checkpoint)
+                if os.path.isdir(os.path.join(checkpoint, folder_name))
+            )
+            # this checks the FSDP state dict when `FULL_STATE_DICT` is used
+            or os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
+        )
+        self.assertTrue(is_fsdp_ckpt)
+
         logs_resume = self.run_cmd_and_get_logs(
             use_accelerate, sharding_strategy, launcher, script, resume_args, output_dir
         )
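
For reference, a minimal standalone sketch of the detection logic this patch adds, assuming FSDP_MODEL_NAME resolves to "pytorch_model_fsdp" as in transformers.trainer; the helper name is_fsdp_checkpoint and the example folder layout are hypothetical, not part of the library API:

import os

# assumed value of transformers.trainer.FSDP_MODEL_NAME
FSDP_MODEL_NAME = "pytorch_model_fsdp"


def is_fsdp_checkpoint(checkpoint_dir: str) -> bool:
    """Return True if `checkpoint_dir` looks like an FSDP checkpoint.

    Mirrors the check added in Trainer._load_from_checkpoint:
    - with `SHARDED_STATE_DICT`, the checkpoint contains a subfolder whose
      name includes FSDP_MODEL_NAME (e.g. "pytorch_model_fsdp_0/"),
    - with `FULL_STATE_DICT`, it contains a single "{FSDP_MODEL_NAME}.bin" file.
    """
    if not os.path.isdir(checkpoint_dir):
        return False
    has_sharded_state_dict = any(
        FSDP_MODEL_NAME in name
        for name in os.listdir(checkpoint_dir)
        if os.path.isdir(os.path.join(checkpoint_dir, name))
    )
    has_full_state_dict = os.path.isfile(os.path.join(checkpoint_dir, f"{FSDP_MODEL_NAME}.bin"))
    return has_sharded_state_dict or has_full_state_dict

The second branch is the new behavior: before this change, a checkpoint saved with `FULL_STATE_DICT` (a flat "pytorch_model_fsdp.bin" file rather than a sharded subfolder) was not recognized as an FSDP checkpoint, so the test now asserts the resumed checkpoint is detected under both state-dict types.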