Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions examples/tests/trainer/test_trainer_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def require_apex(test_case):


class TestTrainerExt(TestCasePlus):
def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, predict_with_generate=True):
def run_seq2seq_quick(self, distributed=False, extra_args_str=None, predict_with_generate=True):
output_dir = self.run_trainer(
eval_steps=1,
max_len=12,
Expand All @@ -83,9 +83,9 @@ def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, p
if predict_with_generate:
assert "eval_bleu" in first_step_stats

last_step_stats = eval_metrics[-1]
assert isinstance(last_step_stats["eval_bleu"], float)
assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"
last_step_stats = eval_metrics[-1]
assert isinstance(last_step_stats["eval_bleu"], float)
assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"

@require_torch_non_multi_gpu
def test_run_seq2seq_no_dist(self):
Expand Down Expand Up @@ -116,14 +116,12 @@ def test_run_seq2seq_sharded_ddp_fp16(self):
# test --sharded_ddp zero_dp_2 w/o --fp16
@require_torch_multi_gpu
@require_fairscale
@unittest.skip("XXX: Fixme: hanging")
def test_run_seq2seq_fully_sharded_ddp(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)

# test --sharded_ddp zero_dp_2 w/ --fp16
@require_torch_multi_gpu
@require_fairscale
@unittest.skip("XXX: Fixme: hanging")
def test_run_seq2seq_fully_sharded_ddp_fp16(self):
self.run_seq2seq_quick(
distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
Expand Down Expand Up @@ -206,8 +204,8 @@ def run_trainer(
--warmup_steps 8
--evaluation_strategy steps
--logging_steps 0
--save_steps {str(eval_steps)}
--eval_steps {str(eval_steps)}
--save_steps {str(eval_steps)}
--group_by_length
--label_smoothing_factor 0.1
--adafactor
Expand Down
22 changes: 14 additions & 8 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,11 +1497,14 @@ def save_model(self, output_dir: Optional[str] = None):
"""
if is_torch_tpu_available():
self._save_tpu(output_dir)
else:
elif (
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
):
state_dict = self.model.state_dict()
if self.is_world_process_zero():
self._save(output_dir)
if self.args.local_rank != -1:
dist.barrier()
self._save(output_dir, state_dict=state_dict)
elif self.is_world_process_zero():
self._save(output_dir)

def _save_tpu(self, output_dir: Optional[str] = None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
Expand Down Expand Up @@ -1531,7 +1534,7 @@ def _save_tpu(self, output_dir: Optional[str] = None):
if self.tokenizer is not None and self.is_world_process_zero():
self.tokenizer.save_pretrained(output_dir)

def _save(self, output_dir: Optional[str] = None):
def _save(self, output_dir: Optional[str] = None, state_dict=None):
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
Expand All @@ -1540,13 +1543,16 @@ def _save(self, output_dir: Optional[str] = None):
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
if isinstance(unwrap_model(self.model), PreTrainedModel):
unwrap_model(self.model).save_pretrained(output_dir, state_dict=self.model.state_dict())
if state_dict is None:
state_dict = self.model.state_dict()
unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
if state_dict is None:
state_dict = self.model.state_dict()
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir)
self.model.save_pretrained(output_dir, state_dict=state_dict)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def get_trainer(self, a=0, b=0, train_len=64, eval_len=64, callbacks=None, disab
config = RegressionModelConfig(a=a, b=b)
model = RegressionPreTrainedModel(config)

args = TrainingArguments(self.output_dir, disable_tqdm=disable_tqdm, **kwargs)
args = TrainingArguments(self.output_dir, disable_tqdm=disable_tqdm, report_to=[], **kwargs)
return Trainer(
model,
args,
Expand Down