From fe7edb36fcb3acdcbeb264011412553720358cc1 Mon Sep 17 00:00:00 2001 From: Rangehow Date: Thu, 21 Aug 2025 19:55:53 +0800 Subject: [PATCH 1/7] fix(trainer): ensure final checkpoint is saved when resuming training --- src/transformers/trainer.py | 36 ++++++++++++++---------------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 23777e886670..3ff2961b02ea 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2584,18 +2584,19 @@ def _inner_training_loop( ) self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) - if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: - self._load_rng_state(resume_from_checkpoint) - + step = -1 rng_to_sync = False - steps_skipped = 0 - if steps_trained_in_current_epoch > 0: - epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch) - steps_skipped = steps_trained_in_current_epoch - steps_trained_in_current_epoch = 0 - rng_to_sync = True - step = -1 + # Handle resumption from checkpoint + if epoch == epochs_trained and resume_from_checkpoint is not None: + if steps_trained_in_current_epoch > 0 and not args.ignore_data_skip: + epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch) + step = steps_trained_in_current_epoch - 1 + rng_to_sync = True + elif steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + + epoch_iterator = iter(epoch_dataloader) # We chunkify the epoch iterator into gradient accumulation steps `n` batches remainder = steps_in_epoch % args.gradient_accumulation_steps @@ -2630,21 +2631,12 @@ def _inner_training_loop( input_tokens = inputs[main_input_name].numel() input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item() + if rng_to_sync: self._load_rng_state(resume_from_checkpoint) rng_to_sync = False - # Skip past any already trained steps if resuming training - if steps_trained_in_current_epoch > 0: - steps_trained_in_current_epoch -= 1 - if steps_trained_progress_bar is not None: - steps_trained_progress_bar.update(1) - if steps_trained_in_current_epoch == 0: - self._load_rng_state(resume_from_checkpoint) - continue - elif steps_trained_progress_bar is not None: - steps_trained_progress_bar.close() - steps_trained_progress_bar = None + if step % args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(args, self.state, self.control) @@ -2737,7 +2729,7 @@ def _inner_training_loop( model.zero_grad() self.state.global_step += 1 - self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch + self.state.epoch = epoch + (step + 1) / steps_in_epoch self.control = self.callback_handler.on_step_end(args, self.state, self.control) self._maybe_log_save_evaluate( tr_loss, From ea0ad02c85bebaa40f2405c5ac64ed3fcd0b748b Mon Sep 17 00:00:00 2001 From: Rangehow Date: Thu, 21 Aug 2025 20:01:22 +0800 Subject: [PATCH 2/7] add test --- tests/trainer/test_trainer_resume.py | 120 +++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 tests/trainer/test_trainer_resume.py diff --git a/tests/trainer/test_trainer_resume.py b/tests/trainer/test_trainer_resume.py new file mode 100644 index 000000000000..642ce734794a --- /dev/null +++ b/tests/trainer/test_trainer_resume.py @@ -0,0 +1,120 @@ +import os +import 
shutil +import torch +from torch.utils.data import TensorDataset, Dataset +from transformers import ( + AutoModelForSequenceClassification, + AutoTokenizer, + Trainer, + TrainingArguments, +) +import torch.nn as nn + + +class DummyModel(nn.Module): + def __init__(self, input_dim=10, num_labels=2): + super().__init__() + self.linear = nn.Linear(input_dim, num_labels) + + def forward(self, input_ids=None, attention_mask=None, labels=None): + logits = self.linear(input_ids.float()) + loss = None + if labels is not None: + loss_fn = nn.CrossEntropyLoss() + loss = loss_fn(logits, labels) + return {"loss": loss, "logits": logits} + +class DummyDictDataset(Dataset): + def __init__(self, input_ids, attention_mask, labels): + self.input_ids = input_ids + self.attention_mask = attention_mask + self.labels = labels + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.labels[idx], + } + +def create_dummy_dataset(): + """Creates a dummy dataset for testing.""" + num_samples = 13 + input_dim = 10 + dummy_input_ids = torch.rand(num_samples, input_dim) + dummy_attention_mask = torch.ones(num_samples, input_dim) + dummy_labels = torch.randint(0, 2, (num_samples,)) + return DummyDictDataset(dummy_input_ids, dummy_attention_mask, dummy_labels) + +def test_resume_with_original_trainer(): + """Tests the original transformers Trainer.""" + print("Testing the original transformers Trainer...") + + # 1. Set up a dummy model + model = DummyModel(input_dim=10, num_labels=2) + dummy_dataset = create_dummy_dataset() + + # 3. First training (simulate interruption) + output_dir_initial = "./test_original_trainer_initial" + training_args_initial = TrainingArguments( + output_dir=output_dir_initial, + num_train_epochs=1, + per_device_train_batch_size=2, + gradient_accumulation_steps=3, + save_strategy="steps", + save_steps=1, # Save at every step + report_to=[], # Disable wandb/tensorboard and other loggers + max_steps=2, # Stop after step 2 to simulate interruption + ) + + trainer_initial = Trainer( + model=model, + args=training_args_initial, + train_dataset=dummy_dataset, + ) + trainer_initial.train() + + # Make sure we have a checkpoint before interruption + checkpoint_path = os.path.join(output_dir_initial, "checkpoint-2") + assert os.path.exists(checkpoint_path) + + print("Second phase") + # 4. Resume training from checkpoint + output_dir_resumed = "./test_original_trainer_resumed" + training_args_resumed = TrainingArguments( + output_dir=output_dir_resumed, + num_train_epochs=1, + per_device_train_batch_size=2, + gradient_accumulation_steps=3, + save_strategy="steps", + save_steps=1, # Keep the same save strategy + ) + + trainer_resumed = Trainer( + model=model, + args=training_args_resumed, + train_dataset=dummy_dataset, + ) + # Resume from the interrupted checkpoint and finish the remaining training + trainer_resumed.train(resume_from_checkpoint=checkpoint_path) + + # 5. Assertion: Check if the final model has been saved + final_model_path = os.path.join(output_dir_resumed,'checkpoint-3', "model.safetensors") + try: + assert os.path.exists(final_model_path), "Original Trainer: Final model checkpoint was not saved!" 
+ print("✓ Original Trainer: Final model has been saved.") + except AssertionError as e: + print(f"✗ Original Trainer: {e}") + + + # Clean up test directories + shutil.rmtree(output_dir_initial) + shutil.rmtree(output_dir_resumed) + + +# Run all tests +if __name__ == "__main__": + test_resume_with_original_trainer() \ No newline at end of file From 37ff9888a5faa033e5579a42e611f8ad7cdbaed7 Mon Sep 17 00:00:00 2001 From: Rangehow Date: Fri, 22 Aug 2025 10:31:27 +0800 Subject: [PATCH 3/7] make style && slight fix of test --- src/transformers/trainer.py | 6 +- tests/trainer/test_trainer_resume.py | 142 ++++++++++++++------------- 2 files changed, 73 insertions(+), 75 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 3ff2961b02ea..de83c818d22d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2523,7 +2523,6 @@ def _inner_training_loop( start_time = time.time() epochs_trained = 0 steps_trained_in_current_epoch = 0 - steps_trained_progress_bar = None # Check if continuing training from a checkpoint if resume_from_checkpoint is not None and os.path.isfile( @@ -2596,7 +2595,6 @@ def _inner_training_loop( elif steps_trained_in_current_epoch == 0: self._load_rng_state(resume_from_checkpoint) - epoch_iterator = iter(epoch_dataloader) # We chunkify the epoch iterator into gradient accumulation steps `n` batches remainder = steps_in_epoch % args.gradient_accumulation_steps @@ -2631,13 +2629,11 @@ def _inner_training_loop( input_tokens = inputs[main_input_name].numel() input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item() - + if rng_to_sync: self._load_rng_state(resume_from_checkpoint) rng_to_sync = False - - if step % args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(args, self.state, self.control) diff --git a/tests/trainer/test_trainer_resume.py b/tests/trainer/test_trainer_resume.py index 642ce734794a..79d81fce9824 100644 --- a/tests/trainer/test_trainer_resume.py +++ b/tests/trainer/test_trainer_resume.py @@ -1,14 +1,15 @@ import os -import shutil + import torch -from torch.utils.data import TensorDataset, Dataset +import torch.nn as nn +from torch.utils.data import Dataset + from transformers import ( - AutoModelForSequenceClassification, - AutoTokenizer, Trainer, TrainingArguments, ) -import torch.nn as nn + +from transformers.testing_utils import TestCasePlus class DummyModel(nn.Module): @@ -24,6 +25,7 @@ def forward(self, input_ids=None, attention_mask=None, labels=None): loss = loss_fn(logits, labels) return {"loss": loss, "logits": logits} + class DummyDictDataset(Dataset): def __init__(self, input_ids, attention_mask, labels): self.input_ids = input_ids @@ -40,6 +42,7 @@ def __getitem__(self, idx): "labels": self.labels[idx], } + def create_dummy_dataset(): """Creates a dummy dataset for testing.""" num_samples = 13 @@ -49,72 +52,71 @@ def create_dummy_dataset(): dummy_labels = torch.randint(0, 2, (num_samples,)) return DummyDictDataset(dummy_input_ids, dummy_attention_mask, dummy_labels) -def test_resume_with_original_trainer(): - """Tests the original transformers Trainer.""" - print("Testing the original transformers Trainer...") - - # 1. Set up a dummy model - model = DummyModel(input_dim=10, num_labels=2) - dummy_dataset = create_dummy_dataset() - - # 3. 
First training (simulate interruption) - output_dir_initial = "./test_original_trainer_initial" - training_args_initial = TrainingArguments( - output_dir=output_dir_initial, - num_train_epochs=1, - per_device_train_batch_size=2, - gradient_accumulation_steps=3, - save_strategy="steps", - save_steps=1, # Save at every step - report_to=[], # Disable wandb/tensorboard and other loggers - max_steps=2, # Stop after step 2 to simulate interruption - ) - - trainer_initial = Trainer( - model=model, - args=training_args_initial, - train_dataset=dummy_dataset, - ) - trainer_initial.train() - - # Make sure we have a checkpoint before interruption - checkpoint_path = os.path.join(output_dir_initial, "checkpoint-2") - assert os.path.exists(checkpoint_path) - - print("Second phase") - # 4. Resume training from checkpoint - output_dir_resumed = "./test_original_trainer_resumed" - training_args_resumed = TrainingArguments( - output_dir=output_dir_resumed, - num_train_epochs=1, - per_device_train_batch_size=2, - gradient_accumulation_steps=3, - save_strategy="steps", - save_steps=1, # Keep the same save strategy - ) - - trainer_resumed = Trainer( - model=model, - args=training_args_resumed, - train_dataset=dummy_dataset, - ) - # Resume from the interrupted checkpoint and finish the remaining training - trainer_resumed.train(resume_from_checkpoint=checkpoint_path) - - # 5. Assertion: Check if the final model has been saved - final_model_path = os.path.join(output_dir_resumed,'checkpoint-3', "model.safetensors") - try: - assert os.path.exists(final_model_path), "Original Trainer: Final model checkpoint was not saved!" - print("✓ Original Trainer: Final model has been saved.") - except AssertionError as e: - print(f"✗ Original Trainer: {e}") - - - # Clean up test directories - shutil.rmtree(output_dir_initial) - shutil.rmtree(output_dir_resumed) + +class TestTrainerResume(TestCasePlus): + def test_resume_with_original_trainer(self): + """Tests the original transformers Trainer.""" + print("Testing the original transformers Trainer...") + + # 1. Set up a dummy model + model = DummyModel(input_dim=10, num_labels=2) + dummy_dataset = create_dummy_dataset() + + # 3. First training (simulate interruption) + output_dir_initial = self.get_auto_remove_tmp_dir() + training_args_initial = TrainingArguments( + output_dir=output_dir_initial, + num_train_epochs=1, + per_device_train_batch_size=2, + gradient_accumulation_steps=3, + save_strategy="steps", + save_steps=1, # Save at every step + report_to=[], # Disable wandb/tensorboard and other loggers + max_steps=2, # Stop after step 2 to simulate interruption + ) + + trainer_initial = Trainer( + model=model, + args=training_args_initial, + train_dataset=dummy_dataset, + ) + trainer_initial.train() + + # Make sure we have a checkpoint before interruption + checkpoint_path = os.path.join(output_dir_initial, "checkpoint-2") + assert os.path.exists(checkpoint_path) + + print("Second phase") + # 4. Resume training from checkpoint + output_dir_resumed = self.get_auto_remove_tmp_dir() + training_args_resumed = TrainingArguments( + output_dir=output_dir_resumed, + num_train_epochs=1, + per_device_train_batch_size=2, + gradient_accumulation_steps=3, + save_strategy="steps", + save_steps=1, # Keep the same save strategy + ) + + trainer_resumed = Trainer( + model=model, + args=training_args_resumed, + train_dataset=dummy_dataset, + ) + # Resume from the interrupted checkpoint and finish the remaining training + trainer_resumed.train(resume_from_checkpoint=checkpoint_path) + + # 5. 
Assertion: Check if the final model has been saved + final_model_path = os.path.join(output_dir_resumed, "checkpoint-3", "model.safetensors") + try: + assert os.path.exists(final_model_path), "Original Trainer: Final model checkpoint was not saved!" + print("✓ Original Trainer: Final model has been saved.") + except AssertionError as e: + print(f"✗ Original Trainer: {e}") # Run all tests if __name__ == "__main__": - test_resume_with_original_trainer() \ No newline at end of file + import unittest + + unittest.main() From 38a216fd1f0bca5cf166b49e35fc0dc44cb5668d Mon Sep 17 00:00:00 2001 From: Rangehow Date: Fri, 22 Aug 2025 10:37:15 +0800 Subject: [PATCH 4/7] make style again --- tests/trainer/test_trainer_resume.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/trainer/test_trainer_resume.py b/tests/trainer/test_trainer_resume.py index 79d81fce9824..cfb34f5f8862 100644 --- a/tests/trainer/test_trainer_resume.py +++ b/tests/trainer/test_trainer_resume.py @@ -8,7 +8,6 @@ Trainer, TrainingArguments, ) - from transformers.testing_utils import TestCasePlus From aa8a637fa158f716836866800b1debd830f05d21 Mon Sep 17 00:00:00 2001 From: rangehow Date: Tue, 26 Aug 2025 22:59:10 +0800 Subject: [PATCH 5/7] move test code to test_trainer --- tests/trainer/test_trainer.py | 109 ++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 683c76032dd0..a7d39518b56a 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -5064,6 +5064,115 @@ def test_trainer_works_without_model_config(self): trainer.train() + @require_safetensors + def test_resume_from_interrupted_training(self): + """ + Tests resuming training from a checkpoint after a simulated interruption. + """ + + # --- Helper classes and functions defined locally for this test --- + class DummyModel(nn.Module): + def __init__(self, input_dim=10, num_labels=2): + super().__init__() + self.linear = nn.Linear(input_dim, num_labels) + + def forward(self, input_ids=None, attention_mask=None, labels=None): + logits = self.linear(input_ids.float()) + loss = None + if labels is not None: + loss_fn = nn.CrossEntropyLoss() + loss = loss_fn(logits, labels) + return {"loss": loss, "logits": logits} + + class DummyDictDataset(torch.utils.data.Dataset): + def __init__(self, input_ids, attention_mask, labels): + self.input_ids = input_ids + self.attention_mask = attention_mask + self.labels = labels + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.labels[idx], + } + + def create_dummy_dataset(): + """Creates a dummy dataset for this specific test.""" + num_samples = 13 + input_dim = 10 + dummy_input_ids = torch.rand(num_samples, input_dim) + dummy_attention_mask = torch.ones(num_samples, input_dim) + dummy_labels = torch.randint(0, 2, (num_samples,)) + return DummyDictDataset(dummy_input_ids, dummy_attention_mask, dummy_labels) + + # 1. Set up a dummy model and dataset + model = DummyModel(input_dim=10, num_labels=2) + dummy_dataset = create_dummy_dataset() + + # 2. 
First training phase (simulating an interruption) + output_dir_initial = self.get_auto_remove_tmp_dir() + training_args_initial = TrainingArguments( + output_dir=output_dir_initial, + num_train_epochs=1, + per_device_train_batch_size=2, + gradient_accumulation_steps=3, + save_strategy="steps", + save_steps=1, # Save at every step + report_to=[], # Disable wandb/tensorboard and other loggers + max_steps=2, # Stop after step 2 to simulate interruption + ) + + trainer_initial = Trainer( + model=model, + args=training_args_initial, + train_dataset=dummy_dataset, + ) + trainer_initial.train() + + # 3. Verify that a checkpoint was created before the "interruption" + checkpoint_path = os.path.join(output_dir_initial, "checkpoint-2") + self.assertTrue(os.path.exists(checkpoint_path), f"Checkpoint not found at {checkpoint_path}") + + # 4. Second training phase (resuming from the checkpoint) + output_dir_resumed = self.get_auto_remove_tmp_dir() + # Note: total steps for one epoch is ceil(13 / (2*3)) = 3. + # We stopped at step 2, so the resumed training should run for 1 more step. + training_args_resumed = TrainingArguments( + output_dir=output_dir_resumed, + num_train_epochs=1, + per_device_train_batch_size=2, + gradient_accumulation_steps=3, + save_strategy="steps", + save_steps=1, + report_to=[], + ) + + trainer_resumed = Trainer( + model=model, + args=training_args_resumed, + train_dataset=dummy_dataset, + ) + # Resume from the interrupted checkpoint and finish the remaining training + trainer_resumed.train(resume_from_checkpoint=checkpoint_path) + + # 5. Assertions: Check if the training completed and the final model was saved + # The training should have completed step 3. + # Total steps per epoch = ceil(13 samples / (2 batch_size * 3 grad_accum)) = 3 + self.assertEqual(trainer_resumed.state.global_step, 3) + + # Check that a checkpoint for the final step exists. + final_checkpoint_path = os.path.join(output_dir_resumed, "checkpoint-3") + self.assertTrue(os.path.exists(final_checkpoint_path)) + + # Check if the model weights file exists in the final checkpoint directory. + # Trainer saves non-PreTrainedModel models as `model.safetensors` by default if safetensors is available. 
+ final_model_path = os.path.join(final_checkpoint_path, SAFE_WEIGHTS_NAME) + self.assertTrue(os.path.exists(final_model_path), "Final model checkpoint was not saved!") + @require_torch @is_staging_test class TrainerIntegrationWithHubTester(unittest.TestCase): From 8712ce12a8965194ede5082cee7869d4878ffad5 Mon Sep 17 00:00:00 2001 From: rangehow Date: Tue, 26 Aug 2025 23:01:16 +0800 Subject: [PATCH 6/7] remove outdated test file --- tests/trainer/test_trainer_resume.py | 121 --------------------------- 1 file changed, 121 deletions(-) delete mode 100644 tests/trainer/test_trainer_resume.py diff --git a/tests/trainer/test_trainer_resume.py b/tests/trainer/test_trainer_resume.py deleted file mode 100644 index cfb34f5f8862..000000000000 --- a/tests/trainer/test_trainer_resume.py +++ /dev/null @@ -1,121 +0,0 @@ -import os - -import torch -import torch.nn as nn -from torch.utils.data import Dataset - -from transformers import ( - Trainer, - TrainingArguments, -) -from transformers.testing_utils import TestCasePlus - - -class DummyModel(nn.Module): - def __init__(self, input_dim=10, num_labels=2): - super().__init__() - self.linear = nn.Linear(input_dim, num_labels) - - def forward(self, input_ids=None, attention_mask=None, labels=None): - logits = self.linear(input_ids.float()) - loss = None - if labels is not None: - loss_fn = nn.CrossEntropyLoss() - loss = loss_fn(logits, labels) - return {"loss": loss, "logits": logits} - - -class DummyDictDataset(Dataset): - def __init__(self, input_ids, attention_mask, labels): - self.input_ids = input_ids - self.attention_mask = attention_mask - self.labels = labels - - def __len__(self): - return len(self.input_ids) - - def __getitem__(self, idx): - return { - "input_ids": self.input_ids[idx], - "attention_mask": self.attention_mask[idx], - "labels": self.labels[idx], - } - - -def create_dummy_dataset(): - """Creates a dummy dataset for testing.""" - num_samples = 13 - input_dim = 10 - dummy_input_ids = torch.rand(num_samples, input_dim) - dummy_attention_mask = torch.ones(num_samples, input_dim) - dummy_labels = torch.randint(0, 2, (num_samples,)) - return DummyDictDataset(dummy_input_ids, dummy_attention_mask, dummy_labels) - - -class TestTrainerResume(TestCasePlus): - def test_resume_with_original_trainer(self): - """Tests the original transformers Trainer.""" - print("Testing the original transformers Trainer...") - - # 1. Set up a dummy model - model = DummyModel(input_dim=10, num_labels=2) - dummy_dataset = create_dummy_dataset() - - # 3. First training (simulate interruption) - output_dir_initial = self.get_auto_remove_tmp_dir() - training_args_initial = TrainingArguments( - output_dir=output_dir_initial, - num_train_epochs=1, - per_device_train_batch_size=2, - gradient_accumulation_steps=3, - save_strategy="steps", - save_steps=1, # Save at every step - report_to=[], # Disable wandb/tensorboard and other loggers - max_steps=2, # Stop after step 2 to simulate interruption - ) - - trainer_initial = Trainer( - model=model, - args=training_args_initial, - train_dataset=dummy_dataset, - ) - trainer_initial.train() - - # Make sure we have a checkpoint before interruption - checkpoint_path = os.path.join(output_dir_initial, "checkpoint-2") - assert os.path.exists(checkpoint_path) - - print("Second phase") - # 4. 
Resume training from checkpoint - output_dir_resumed = self.get_auto_remove_tmp_dir() - training_args_resumed = TrainingArguments( - output_dir=output_dir_resumed, - num_train_epochs=1, - per_device_train_batch_size=2, - gradient_accumulation_steps=3, - save_strategy="steps", - save_steps=1, # Keep the same save strategy - ) - - trainer_resumed = Trainer( - model=model, - args=training_args_resumed, - train_dataset=dummy_dataset, - ) - # Resume from the interrupted checkpoint and finish the remaining training - trainer_resumed.train(resume_from_checkpoint=checkpoint_path) - - # 5. Assertion: Check if the final model has been saved - final_model_path = os.path.join(output_dir_resumed, "checkpoint-3", "model.safetensors") - try: - assert os.path.exists(final_model_path), "Original Trainer: Final model checkpoint was not saved!" - print("✓ Original Trainer: Final model has been saved.") - except AssertionError as e: - print(f"✗ Original Trainer: {e}") - - -# Run all tests -if __name__ == "__main__": - import unittest - - unittest.main() From 0967dc82c3735fd02e9f7d819af4d9d6bc0940cf Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 26 Aug 2025 15:04:51 +0000 Subject: [PATCH 7/7] Apply style fixes --- tests/trainer/test_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a7d39518b56a..6ea4dabc2ef6 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -5063,7 +5063,6 @@ def test_trainer_works_without_model_config(self): ) trainer.train() - @require_safetensors def test_resume_from_interrupted_training(self): """ @@ -5173,6 +5172,7 @@ def create_dummy_dataset(): final_model_path = os.path.join(final_checkpoint_path, SAFE_WEIGHTS_NAME) self.assertTrue(os.path.exists(final_model_path), "Final model checkpoint was not saved!") + @require_torch @is_staging_test class TrainerIntegrationWithHubTester(unittest.TestCase):
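

For reference, the step arithmetic behind the test's assertions, as a minimal standalone Python sketch (all numbers are taken from the test above: 13 samples, per-device batch size 2, gradient_accumulation_steps 3, interruption at global step 2):

import math

num_samples = 13
batch_size = 2
grad_accum = 3

# ceil(13 / 2) = 7 micro-batches per epoch (the last one is a partial batch)
micro_batches_per_epoch = math.ceil(num_samples / batch_size)
# 7 micro-batches grouped by 3 -> ceil(7 / 3) = 3 optimizer steps per epoch
optimizer_steps_per_epoch = math.ceil(micro_batches_per_epoch / grad_accum)

# Training was interrupted after global step 2, so the resumed run must
# perform exactly one more optimizer step and save checkpoint-3.
resumed_from_step = 2
remaining_steps = optimizer_steps_per_epoch - resumed_from_step

print(micro_batches_per_epoch, optimizer_steps_per_epoch, remaining_steps)  # 7 3 1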