fix bugs with trainer (#24134)
* fix the deepspeed test failures

* apex fix

* FSDP save ckpt fix

* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
pacman100 and sgugger committed Jun 9, 2023
1 parent 50ed793 commit a272e41
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/transformers/trainer.py
@@ -1749,7 +1749,16 @@ def _inner_training_loop(
 
         # prepare using `accelerator` prepare
         if use_accelerator_prepare:
-            model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+            if hasattr(self.lr_scheduler, "step"):
+                if self.use_apex:
+                    model = self.accelerator.prepare(self.model)
+                else:
+                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+            else:
+                # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
+                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
+                    self.model, self.optimizer, self.lr_scheduler
+                )
 
         if self.is_fsdp_enabled:
             self.model = model
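
For context on the DeepSpeed part of the fix: when the scheduler is declared inside the DeepSpeed config rather than created by the Trainer, accelerate hands the Trainer a DummyScheduler placeholder, and model, optimizer and scheduler then all have to go through accelerator.prepare together so DeepSpeed can build the real scheduler. A minimal, illustrative sketch of a config that exercises this path (the values and names below are placeholders, not taken from this commit, and running it needs accelerate and deepspeed installed):

# Illustrative DeepSpeed config: declaring "scheduler" here, rather than letting
# the Trainer build one, is what puts the Trainer on the DummyScheduler path.
ds_config = {
    "zero_optimization": {"stage": 2},
    "optimizer": {"type": "AdamW", "params": {"lr": "auto"}},
    "scheduler": {"type": "WarmupLR", "params": {"warmup_num_steps": "auto"}},
    "train_micro_batch_size_per_gpu": "auto",
}

from transformers import TrainingArguments

# Passing the config dict via `deepspeed=` is standard Trainer usage; the rest of
# the training setup (model, datasets, Trainer itself) is omitted here.
args = TrainingArguments(output_dir="tmp_trainer", deepspeed=ds_config)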
@@ -2841,6 +2850,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
             or self.is_fsdp_enabled
         ):
             if self.is_fsdp_enabled:
+                os.makedirs(output_dir, exist_ok=True)
                 self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir)
             else:
                 state_dict = self.model.state_dict()
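
On the FSDP checkpoint part of the fix: fsdp_plugin.save_model writes the model state dict into output_dir, so the directory must exist before the call; the added os.makedirs(output_dir, exist_ok=True) makes saving to a not-yet-created path work. The same pattern in isolation (the path is illustrative):

import os

# Create the target directory up front so the subsequent save cannot fail on a
# missing path; exist_ok=True keeps the call safe if the directory already exists.
output_dir = "outputs/fsdp-run"  # illustrative path
os.makedirs(output_dir, exist_ok=True)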
