diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 49e14ce56574..97de1d6d5397 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2533,7 +2533,6 @@ def _inner_training_loop(
         start_time = time.time()
         epochs_trained = 0
         steps_trained_in_current_epoch = 0
-        steps_trained_progress_bar = None
 
         # Check if continuing training from a checkpoint
         if resume_from_checkpoint is not None and os.path.isfile(
@@ -2594,18 +2593,18 @@ def _inner_training_loop(
             )
             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
 
-            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
-                self._load_rng_state(resume_from_checkpoint)
-
+            step = -1
             rng_to_sync = False
-            steps_skipped = 0
-            if steps_trained_in_current_epoch > 0:
-                epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
-                steps_skipped = steps_trained_in_current_epoch
-                steps_trained_in_current_epoch = 0
-                rng_to_sync = True
-            step = -1
+            # Handle resumption from checkpoint
+            if epoch == epochs_trained and resume_from_checkpoint is not None:
+                if steps_trained_in_current_epoch > 0 and not args.ignore_data_skip:
+                    epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
+                    step = steps_trained_in_current_epoch - 1
+                    rng_to_sync = True
+                elif steps_trained_in_current_epoch == 0:
+                    self._load_rng_state(resume_from_checkpoint)
+
             epoch_iterator = iter(epoch_dataloader)
             # We chunkify the epoch iterator into gradient accumulation steps `n` batches
             remainder = steps_in_epoch % args.gradient_accumulation_steps
@@ -2658,22 +2657,11 @@ def _inner_training_loop(
                             input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
                             self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item()
+
                     if rng_to_sync:
                         self._load_rng_state(resume_from_checkpoint)
                         rng_to_sync = False
 
-                    # Skip past any already trained steps if resuming training
-                    if steps_trained_in_current_epoch > 0:
-                        steps_trained_in_current_epoch -= 1
-                        if steps_trained_progress_bar is not None:
-                            steps_trained_progress_bar.update(1)
-                        if steps_trained_in_current_epoch == 0:
-                            self._load_rng_state(resume_from_checkpoint)
-                        continue
-                    elif steps_trained_progress_bar is not None:
-                        steps_trained_progress_bar.close()
-                        steps_trained_progress_bar = None
-
                     if step % args.gradient_accumulation_steps == 0:
                         self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
@@ -2765,7 +2753,7 @@ def _inner_training_loop(
 
                         model.zero_grad()
                         self.state.global_step += 1
-                        self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
+                        self.state.epoch = epoch + (step + 1) / steps_in_epoch
                         self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                         self._maybe_log_save_evaluate(
                             tr_loss,
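The core of the trainer change above is the bookkeeping: instead of decrementing `steps_trained_in_current_epoch` batch by batch and tracking `steps_skipped` separately, the already-trained batches are dropped from the epoch dataloader up front and `step` is pre-advanced to `steps_trained_in_current_epoch - 1`, so the existing `epoch + (step + 1) / steps_in_epoch` formula stays correct on its own. Below is a minimal standalone sketch of that arithmetic, using a plain Python list in place of the real dataloader and a hypothetical `resume_epoch` helper (not Trainer code, just an illustration under those assumptions):

# Hypothetical sketch (not Trainer code): pre-advancing `step` after skipping
# already-trained batches keeps the fractional-epoch formula correct without a
# separate `steps_skipped` counter.
def resume_epoch(batches, steps_trained, epoch=0):
    steps_in_epoch = len(batches)
    remaining = batches[steps_trained:]  # stands in for skip_first_batches(...)
    step = steps_trained - 1             # pre-advance, as the patch does
    fractional_epochs = []
    for _batch in remaining:
        step += 1
        fractional_epochs.append(epoch + (step + 1) / steps_in_epoch)
    return fractional_epochs

print(resume_epoch(list(range(6)), steps_trained=2))
# [0.5, 0.6666666666666666, 0.8333333333333334, 1.0]

With 2 of 6 batches already trained, the fractional epoch picks up at 0.5 and reaches 1.0 on the last batch, which is the same progression the old `steps_skipped` term produced.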
+ """ + + # --- Helper classes and functions defined locally for this test --- + class DummyModel(nn.Module): + def __init__(self, input_dim=10, num_labels=2): + super().__init__() + self.linear = nn.Linear(input_dim, num_labels) + + def forward(self, input_ids=None, attention_mask=None, labels=None): + logits = self.linear(input_ids.float()) + loss = None + if labels is not None: + loss_fn = nn.CrossEntropyLoss() + loss = loss_fn(logits, labels) + return {"loss": loss, "logits": logits} + + class DummyDictDataset(torch.utils.data.Dataset): + def __init__(self, input_ids, attention_mask, labels): + self.input_ids = input_ids + self.attention_mask = attention_mask + self.labels = labels + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.labels[idx], + } + + def create_dummy_dataset(): + """Creates a dummy dataset for this specific test.""" + num_samples = 13 + input_dim = 10 + dummy_input_ids = torch.rand(num_samples, input_dim) + dummy_attention_mask = torch.ones(num_samples, input_dim) + dummy_labels = torch.randint(0, 2, (num_samples,)) + return DummyDictDataset(dummy_input_ids, dummy_attention_mask, dummy_labels) + + # 1. Set up a dummy model and dataset + model = DummyModel(input_dim=10, num_labels=2) + dummy_dataset = create_dummy_dataset() + + # 2. First training phase (simulating an interruption) + output_dir_initial = self.get_auto_remove_tmp_dir() + training_args_initial = TrainingArguments( + output_dir=output_dir_initial, + num_train_epochs=1, + per_device_train_batch_size=2, + gradient_accumulation_steps=3, + save_strategy="steps", + save_steps=1, # Save at every step + report_to=[], # Disable wandb/tensorboard and other loggers + max_steps=2, # Stop after step 2 to simulate interruption + ) + + trainer_initial = Trainer( + model=model, + args=training_args_initial, + train_dataset=dummy_dataset, + ) + trainer_initial.train() + + # 3. Verify that a checkpoint was created before the "interruption" + checkpoint_path = os.path.join(output_dir_initial, "checkpoint-2") + self.assertTrue(os.path.exists(checkpoint_path), f"Checkpoint not found at {checkpoint_path}") + + # 4. Second training phase (resuming from the checkpoint) + output_dir_resumed = self.get_auto_remove_tmp_dir() + # Note: total steps for one epoch is ceil(13 / (2*3)) = 3. + # We stopped at step 2, so the resumed training should run for 1 more step. + training_args_resumed = TrainingArguments( + output_dir=output_dir_resumed, + num_train_epochs=1, + per_device_train_batch_size=2, + gradient_accumulation_steps=3, + save_strategy="steps", + save_steps=1, + report_to=[], + ) + + trainer_resumed = Trainer( + model=model, + args=training_args_resumed, + train_dataset=dummy_dataset, + ) + # Resume from the interrupted checkpoint and finish the remaining training + trainer_resumed.train(resume_from_checkpoint=checkpoint_path) + + # 5. Assertions: Check if the training completed and the final model was saved + # The training should have completed step 3. + # Total steps per epoch = ceil(13 samples / (2 batch_size * 3 grad_accum)) = 3 + self.assertEqual(trainer_resumed.state.global_step, 3) + + # Check that a checkpoint for the final step exists. + final_checkpoint_path = os.path.join(output_dir_resumed, "checkpoint-3") + self.assertTrue(os.path.exists(final_checkpoint_path)) + + # Check if the model weights file exists in the final checkpoint directory. 
+        # Trainer saves non-PreTrainedModel models as `model.safetensors` by default if safetensors is available.
+        final_model_path = os.path.join(final_checkpoint_path, SAFE_WEIGHTS_NAME)
+        self.assertTrue(os.path.exists(final_model_path), "Final model checkpoint was not saved!")
+
 
 @require_torch
 @is_staging_test
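For reference, the step counts the new test asserts can be checked with a quick back-of-the-envelope calculation (plain Python, independent of Trainer; `interrupt_at` is just an illustrative name for the first run's `max_steps=2`):

import math

num_samples = 13
per_device_train_batch_size = 2
gradient_accumulation_steps = 3

# Batches the dataloader yields per epoch, then optimizer steps per epoch
# (the last, partial accumulation window still counts as one update).
batches_per_epoch = math.ceil(num_samples / per_device_train_batch_size)                # 7
optimizer_steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)  # 3

interrupt_at = 2                                            # first run stops at global_step 2
remaining_steps = optimizer_steps_per_epoch - interrupt_at  # 1 step left after resuming

print(batches_per_epoch, optimizer_steps_per_epoch, remaining_steps)  # 7 3 1

This matches the test's expectations: the resumed run finishes at `global_step == 3` and writes `checkpoint-3`.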