[trainer] fix double wrapping + test #10583

Merged: 2 commits, Mar 8, 2021

4 changes: 4 additions & 0 deletions src/transformers/trainer.py
@@ -738,6 +738,10 @@ def _wrap_model(self, model, training=True):
         if self.deepspeed:
             return self.deepspeed
 
+        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
+        if unwrap_model(model) is not model:
+            return model
+
         # Mixed precision training with apex (torch < 1.6)
         if self.use_apex and training:
             model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
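
For context, the guard works because `unwrap_model` (in `transformers.modeling_utils`) recursively follows the `.module` attribute that containers like `DataParallel` and `DistributedDataParallel` expose, so `unwrap_model(model) is not model` only holds when the model is already wrapped. A minimal sketch of that idea, using a locally defined `unwrap` helper rather than the library function:

```python
import torch.nn as nn


def unwrap(model: nn.Module) -> nn.Module:
    # DataParallel / DistributedDataParallel keep the original model in `.module`;
    # follow that chain until the bare model is reached.
    if hasattr(model, "module"):
        return unwrap(model.module)
    return model


model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)

print(unwrap(model) is model)      # True  -> bare model, safe to wrap
print(unwrap(wrapped) is wrapped)  # False -> already wrapped, return it as-is
```
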
13 changes: 13 additions & 0 deletions tests/test_trainer.py
@@ -574,6 +574,19 @@ def test_gradient_accumulation(self):
         trainer.train()
         self.check_trained_model(trainer.model)
 
+    @require_torch_multi_gpu
+    def test_run_seq2seq_double_train_wrap_once(self):
+        # test that we don't wrap the model more than once
+        # since wrapping primarily happens on multi-gpu setup we want multiple gpus to test for
+        # example DataParallel(DataParallel(model))
+
+        trainer = get_regression_trainer()
+        trainer.train()
+        model_wrapped_before = trainer.model_wrapped
+        trainer.train()
+        model_wrapped_after = trainer.model_wrapped
+        self.assertIs(model_wrapped_before, model_wrapped_after, "should be not wrapped twice")
+
     def test_can_resume_training(self):
         if torch.cuda.device_count() > 2:
             # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
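
Without the guard added in `_wrap_model`, each additional `train()` call on a multi-GPU machine would wrap the already-wrapped model again, e.g. `DataParallel(DataParallel(model))`; the test pins the wrapped object via `trainer.model_wrapped` and checks it stays the same object across runs. A rough standalone sketch of the failure mode and the guard, with a hypothetical `wrap_once` helper (not library code):

```python
import torch.nn as nn


def wrap_once(model: nn.Module) -> nn.Module:
    # Hypothetical stand-in for Trainer._wrap_model: skip wrapping when the
    # model already sits inside a container that exposes `.module`.
    if hasattr(model, "module"):
        return model
    return nn.DataParallel(model)


model = nn.Linear(4, 2)

wrapped_first = wrap_once(model)            # first training run: wrap once
wrapped_second = wrap_once(wrapped_first)   # second run: guard prevents re-wrapping

assert wrapped_second is wrapped_first      # mirrors assertIs(model_wrapped_before, model_wrapped_after)
```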