src/transformers/trainer.py (12 additions, 24 deletions)
@@ -2533,7 +2533,6 @@ def _inner_training_loop(
         start_time = time.time()
         epochs_trained = 0
         steps_trained_in_current_epoch = 0
-        steps_trained_progress_bar = None

         # Check if continuing training from a checkpoint
         if resume_from_checkpoint is not None and os.path.isfile(
@@ -2594,18 +2593,18 @@
             )
             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

-            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
-                self._load_rng_state(resume_from_checkpoint)
-
-            step = -1
             rng_to_sync = False
-            steps_skipped = 0
-            if steps_trained_in_current_epoch > 0:
-                epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
-                steps_skipped = steps_trained_in_current_epoch
-                steps_trained_in_current_epoch = 0
-                rng_to_sync = True

+            step = -1
+            # Handle resumption from checkpoint
+            if epoch == epochs_trained and resume_from_checkpoint is not None:
+                if steps_trained_in_current_epoch > 0 and not args.ignore_data_skip:
+                    epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
+                    step = steps_trained_in_current_epoch - 1
+                    rng_to_sync = True
+                elif steps_trained_in_current_epoch == 0:
+                    self._load_rng_state(resume_from_checkpoint)
+
             epoch_iterator = iter(epoch_dataloader)
             # We chunkify the epoch iterator into gradient accumulation steps `n` batches
             remainder = steps_in_epoch % args.gradient_accumulation_steps
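
The guard on `args.ignore_data_skip` in the added branch refers to the existing `TrainingArguments` flag that lets users opt out of replaying the dataloader on resume. As a reading aid, here is a toy decision table (a hypothetical helper written for this annotation, not Trainer API) that mirrors the branch structure of the new resumption block for the first resumed epoch:

def resume_action(steps_trained_in_current_epoch: int, ignore_data_skip: bool) -> str:
    """Mirrors the branch structure of the new resumption block above."""
    if steps_trained_in_current_epoch > 0 and not ignore_data_skip:
        return "fast-forward dataloader, preset `step`, defer RNG restore to the loop"
    elif steps_trained_in_current_epoch == 0:
        return "restore RNG state immediately (resumed exactly on an epoch boundary)"
    return "fall through: no fast-forward here (ignore_data_skip was set)"


print(resume_action(4, False))  # mid-epoch resume with exact data order
print(resume_action(0, False))  # resume on an epoch boundary
print(resume_action(4, True))   # user opted out of data skipping
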
@@ -2658,22 +2657,11 @@

                     input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
                     self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item()

                     if rng_to_sync:
                         self._load_rng_state(resume_from_checkpoint)
                         rng_to_sync = False

-                    # Skip past any already trained steps if resuming training
-                    if steps_trained_in_current_epoch > 0:
-                        steps_trained_in_current_epoch -= 1
-                        if steps_trained_progress_bar is not None:
-                            steps_trained_progress_bar.update(1)
-                        if steps_trained_in_current_epoch == 0:
-                            self._load_rng_state(resume_from_checkpoint)
-                        continue
-                    elif steps_trained_progress_bar is not None:
-                        steps_trained_progress_bar.close()
-                        steps_trained_progress_bar = None
-
                     if step % args.gradient_accumulation_steps == 0:
                         self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

@@ -2765,7 +2753,7 @@

                         model.zero_grad()
                         self.state.global_step += 1
-                        self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
+                        self.state.epoch = epoch + (step + 1) / steps_in_epoch
                         self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                         self._maybe_log_save_evaluate(
                             tr_loss,
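
Taken together, the four hunks replace the old batch-by-batch skipping (with its progress-bar bookkeeping and the `steps_skipped` correction in `state.epoch`) with a single up-front fast-forward. A minimal standalone sketch of the idea follows; it uses the real `accelerate.skip_first_batches` API, but the loop, values, and variable names are simplified stand-ins, not Trainer internals:

from accelerate import skip_first_batches
from torch.utils.data import DataLoader

dataloader = DataLoader(list(range(10)), batch_size=1)
steps_in_epoch = len(dataloader)           # 10
steps_trained_in_current_epoch = 4         # as would be derived from the saved trainer state

# Fast-forward once, instead of consuming skipped batches inside the step loop.
epoch_dataloader = skip_first_batches(dataloader, steps_trained_in_current_epoch)
step = steps_trained_in_current_epoch - 1  # the first real batch then lands on step 4

epoch = 0
for batch in epoch_dataloader:
    step += 1
    # Because `step` is preset, no separate `steps_skipped` correction is needed:
    state_epoch = epoch + (step + 1) / steps_in_epoch

print(state_epoch)  # 1.0 at the end of the epoch
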
tests/trainer/test_trainer.py (109 additions, 0 deletions)
@@ -5158,6 +5158,115 @@ def test_trainer_works_without_model_config(self):
        )
        trainer.train()

    @require_safetensors
    def test_resume_from_interrupted_training(self):
        """
        Tests resuming training from a checkpoint after a simulated interruption.
        """

        # --- Helper classes and functions defined locally for this test ---
        class DummyModel(nn.Module):
            def __init__(self, input_dim=10, num_labels=2):
                super().__init__()
                self.linear = nn.Linear(input_dim, num_labels)

            def forward(self, input_ids=None, attention_mask=None, labels=None):
                logits = self.linear(input_ids.float())
                loss = None
                if labels is not None:
                    loss_fn = nn.CrossEntropyLoss()
                    loss = loss_fn(logits, labels)
                return {"loss": loss, "logits": logits}

        class DummyDictDataset(torch.utils.data.Dataset):
            def __init__(self, input_ids, attention_mask, labels):
                self.input_ids = input_ids
                self.attention_mask = attention_mask
                self.labels = labels

            def __len__(self):
                return len(self.input_ids)

            def __getitem__(self, idx):
                return {
                    "input_ids": self.input_ids[idx],
                    "attention_mask": self.attention_mask[idx],
                    "labels": self.labels[idx],
                }

        def create_dummy_dataset():
            """Creates a dummy dataset for this specific test."""
            num_samples = 13
            input_dim = 10
            # Random float features stand in for input_ids; DummyModel consumes them directly.
            dummy_input_ids = torch.rand(num_samples, input_dim)
            dummy_attention_mask = torch.ones(num_samples, input_dim)
            dummy_labels = torch.randint(0, 2, (num_samples,))
            return DummyDictDataset(dummy_input_ids, dummy_attention_mask, dummy_labels)

        # 1. Set up a dummy model and dataset
        model = DummyModel(input_dim=10, num_labels=2)
        dummy_dataset = create_dummy_dataset()

        # 2. First training phase (simulating an interruption)
        output_dir_initial = self.get_auto_remove_tmp_dir()
        training_args_initial = TrainingArguments(
            output_dir=output_dir_initial,
            num_train_epochs=1,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=3,
            save_strategy="steps",
            save_steps=1,  # Save at every step
            report_to=[],  # Disable wandb/tensorboard and other loggers
            max_steps=2,  # Stop after 2 optimizer steps to simulate an interruption (overrides num_train_epochs)
        )

        trainer_initial = Trainer(
            model=model,
            args=training_args_initial,
            train_dataset=dummy_dataset,
        )
        trainer_initial.train()

        # 3. Verify that a checkpoint was created before the "interruption"
        checkpoint_path = os.path.join(output_dir_initial, "checkpoint-2")
        self.assertTrue(os.path.exists(checkpoint_path), f"Checkpoint not found at {checkpoint_path}")

        # 4. Second training phase (resuming from the checkpoint)
        output_dir_resumed = self.get_auto_remove_tmp_dir()
        # Note: the dataloader yields ceil(13 / 2) = 7 batches per epoch, so one epoch
        # has ceil(7 / 3) = 3 optimizer steps. We stopped at step 2, so the resumed
        # training should run for 1 more step.
        training_args_resumed = TrainingArguments(
            output_dir=output_dir_resumed,
            num_train_epochs=1,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=3,
            save_strategy="steps",
            save_steps=1,
            report_to=[],
        )

        trainer_resumed = Trainer(
            model=model,
            args=training_args_resumed,
            train_dataset=dummy_dataset,
        )
        # Resume from the interrupted checkpoint and finish the remaining training
        trainer_resumed.train(resume_from_checkpoint=checkpoint_path)

        # 5. Assertions: check that training completed and the final model was saved.
        # The resumed run should finish the epoch at global step 3
        # (ceil(ceil(13 / 2) / 3) = 3 optimizer steps per epoch).
        self.assertEqual(trainer_resumed.state.global_step, 3)

        # Check that a checkpoint for the final step exists.
        final_checkpoint_path = os.path.join(output_dir_resumed, "checkpoint-3")
        self.assertTrue(os.path.exists(final_checkpoint_path))

        # Check that the model weights file exists in the final checkpoint directory.
        # Trainer saves non-PreTrainedModel models as `model.safetensors` by default
        # if safetensors is available.
        final_model_path = os.path.join(final_checkpoint_path, SAFE_WEIGHTS_NAME)
        self.assertTrue(os.path.exists(final_model_path), "Final model checkpoint was not saved!")


@require_torch
@is_staging_test
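
As a closing sanity check on the step counts the new test relies on (13 samples, per-device batch size 2, gradient accumulation 3, single process), the arithmetic from the test comments can be verified directly in plain Python:

import math

num_samples, batch_size, grad_accum = 13, 2, 3

batches_per_epoch = math.ceil(num_samples / batch_size)      # 7 (drop_last=False keeps the odd batch)
steps_per_epoch = math.ceil(batches_per_epoch / grad_accum)  # 3 optimizer steps per epoch

assert steps_per_epoch == 3
# The first run stops at max_steps=2, so the resumed run performs exactly 1 more
# step and ends at global step 3, matching the assertions in the test.
assert steps_per_epoch - 2 == 1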