Commit

[deepspeed] saving checkpoint fallback when fp16 weights aren't saved (#14948)

* [deepspeed] saving checkpoint fallback when fp16 weights aren't saved

* Bump required deepspeed version to match usage when saving checkpoints

* update version

Co-authored-by: Mihai Balint <balint.mihai@gmail.com>
stas00 and MihaiBalint authored Jan 28, 2022
1 parent d25e25e commit 297602c
Showing 3 changed files with 8 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -98,7 +98,7 @@
     "cookiecutter==1.7.2",
     "dataclasses",
     "datasets",
-    "deepspeed>=0.5.7",
+    "deepspeed>=0.5.9",
     "fairscale>0.3",
     "faiss-cpu",
     "fastapi",
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
@@ -8,7 +8,7 @@
     "cookiecutter": "cookiecutter==1.7.2",
     "dataclasses": "dataclasses",
     "datasets": "datasets",
-    "deepspeed": "deepspeed>=0.5.7",
+    "deepspeed": "deepspeed>=0.5.9",
     "fairscale": "fairscale>0.3",
     "faiss-cpu": "faiss-cpu",
     "fastapi": "fastapi",
7 changes: 6 additions & 1 deletion src/transformers/trainer.py
@@ -2054,7 +2054,12 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
                 # now save the real model if stage3_gather_fp16_weights_on_model_save=True
                 # if false it will not be saved.
                 # This must be called on all ranks
-                self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME)
+                if not self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME):
+                    logger.warning(
+                        "deepspeed.save_fp16_model didn't save the model, since stage3_gather_fp16_weights_on_model_save=false. "
+                        "Saving the full checkpoint instead, use zero_to_fp32.py to recover weights"
+                    )
+                    self.deepspeed.save_checkpoint(output_dir)

         elif self.args.should_save:
             self._save(output_dir)
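
The control flow introduced in the trainer.py hunk can be tried in isolation with the minimal sketch below. It is an illustration under stated assumptions, not the Trainer's or DeepSpeed's code: `StubEngine`, `save_with_fallback`, and the `checkpoint.marker` file are hypothetical names invented here. The only behaviour borrowed from the diff is that `save_fp16_model` returns a falsy value when the consolidated fp16 weights were not written, in which case `save_checkpoint` is called instead.

```python
import logging
import os
import tempfile

logger = logging.getLogger(__name__)

# File name used for the consolidated weights in this sketch.
WEIGHTS_NAME = "pytorch_model.bin"


class StubEngine:
    """Hypothetical stand-in for a DeepSpeed engine; NOT the real API."""

    def __init__(self, gather_fp16_weights: bool):
        # Plays the role of stage3_gather_fp16_weights_on_model_save.
        self.gather_fp16_weights = gather_fp16_weights

    def save_fp16_model(self, save_dir: str, save_filename: str) -> bool:
        # Assumption taken from the diff: return False when nothing was saved.
        if not self.gather_fp16_weights:
            return False
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, save_filename), "wb") as f:
            f.write(b"fp16-weights")
        return True

    def save_checkpoint(self, save_dir: str) -> None:
        # Writes a placeholder file standing in for a full checkpoint.
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "checkpoint.marker"), "wb") as f:
            f.write(b"full-checkpoint")


def save_with_fallback(engine, output_dir: str) -> str:
    """Try the fp16 weights first; fall back to a full checkpoint.

    Returns "fp16" or "checkpoint" to report which artifact was written.
    """
    if engine.save_fp16_model(output_dir, WEIGHTS_NAME):
        return "fp16"
    logger.warning("fp16 weights not saved; saving the full checkpoint instead")
    engine.save_checkpoint(output_dir)
    return "checkpoint"


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        print(save_with_fallback(StubEngine(gather_fp16_weights=False), tmp))
```

Checking the return value rather than the config flag keeps the caller independent of how the engine decides whether to gather the weights, which appears to be the design choice the commit makes.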
