Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bugs with trainer #24134

Merged
merged 4 commits into from
Jun 9, 2023
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1749,7 +1749,16 @@ def _inner_training_loop(

# prepare using `accelerator` prepare
if use_accelerator_prepare:
model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
if hasattr(self.lr_scheduler, "step"):
if self.use_apex:
model = self.accelerator.prepare(self.model)
else:
model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
else:
# to handle cases wherein we pass "DummyScheduler" such as when it is sopecified in DeepSpeed config.
pacman100 marked this conversation as resolved.
Show resolved Hide resolved
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
self.model, self.optimizer, self.lr_scheduler
)

if self.is_fsdp_enabled:
self.model = model
Expand Down Expand Up @@ -2841,6 +2850,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
or self.is_fsdp_enabled
):
if self.is_fsdp_enabled:
os.makedirs(output_dir, exist_ok=True)
self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir)
else:
state_dict = self.model.state_dict()
Expand Down
Loading